本文首發於:行者AI
眾所周知,在基於價值學習的強化學習算法中,如DQN,函數近似誤差是導致Q值高估和次優策略的原因。我們表明這個問題依然在AC框架中存在,並提出了新的機制去最小化它對演員(策略函數)和評論家(估值函數)的影響。我們的算法建立在雙Q學習的基礎上,通過選取兩個估值函數中的較小值,從而限制它對Q值的過高估計。(出自TD3論文摘要)
1. 什么是TD3
TD3是Twin Delayed Deep Deterministic policy gradient algorithm的全稱。TD3全稱中Deep Deterministic policy gradient algorithm就是DDPG的全稱。那么DDPG和TD3有何淵源呢?其實簡單的說,TD3是DDPG的一個優化版本。
1.1 TD3為什么被提出
在強化學習中,對於離散化的動作的學習,都是以DQN為基礎的,DQN則是通過的\(argMaxQ_{table}\)的方式去選擇動作,往往都會過大的估計價值函數,從而造成誤差。在連續的動作控制的AC框架中,如果每一步都采用這種方式去估計,導致誤差一步一步的累加,導致不能找到最優策略,最終使算法不能得到收斂。
1.2 TD3在DDPG的基礎上都做了些什么
-
使用兩個Critic網絡。使用兩個網絡對動作價值函數進行估計,(這Double DQN 的思想差不多)。在訓練的時候選擇\(min(Q^{\theta1}(s,a),Q^{\theta2}(s,a))\)作為估計值。
-
使用軟更新的方式 。不再采用直接復制,而是使用 \(\theta = \tau\theta^′ + (1 - \tau)\theta\)的方式更新網絡參數。
-
使用策略噪音。使用Epsilon-Greedy在探索的時候使用了探索噪音。(還是用了策略噪聲,在更新參數的時候,用於平滑策略期望)
-
使用延遲學習。Critic網絡更新的頻率要比Actor網絡更新的頻率要大。
-
使用梯度截取。將Actor的參數更新的梯度截取到某個范圍內。
2. TD3算法思路
TD3算法的大致思路,首先初始化3個網絡,分別為\(Q_{\theta1},Q_{\theta2},\pi_\phi\) ,參數為\(\theta_1,\theta_2,\phi\),在初始化3個Target網絡,分別將開始初始化的3個網絡參數分別對應的復制給target網絡。\(\theta{_1^′}\leftarrow\theta_1,\theta{_2^′}\leftarrow\theta_2,\phi_′\leftarrow\phi\) 。初始化Replay Buffer \(\beta\) 。然后通過循環迭代,一次次找到最優策略。每次迭代,在選擇action的值的時候加入了噪音,使\(a~\pi_\phi(s) + \epsilon\),\(\epsilon \sim N(0,\sigma)\),然后將\((s,a,r,s^′)\)放入\(\beta\),當\(\beta\)達到一定的值時候。然后隨機從\(\beta\)中Sample出Mini-Batch個數據,通過\(\tilde{a} \sim\pi_{\phi^′}(s^′) + \epsilon\),\(\epsilon \sim clip(N(0,\tilde\sigma),-c,c)\),計算出\(s^′\)狀態下對應的Action的值\(\tilde a\),通過\(s^′,\tilde a\),計算出\(targetQ1,targetQ2\),獲取\(min(targetQ1,targetQ)\),為\(s^′\)的\(targetQ\)值。
通過貝爾曼方程計算\(s\)的\(targetQ\)值,通過兩個Current網絡根據\(s,a\)分別計算出當前的\(Q\)值,在將兩個當前網絡的\(Q\)值和\(targetQ\)值通過MSE計算Loss,更新參數。Critic網絡更新之后,Actor網絡則采用了延時更新,(一般采用Critic更新2次,Actor更新1次)。通過梯度上升的方式更新Actor網絡。通過軟更新的方式,更新target網絡。
-
為什么在更新Critic網絡時,在計算Action值的時候加入噪音,是為了平滑前面加入的噪音。
-
貝爾曼方程:針對一個連續的MRP(Markov Reward Process)的過程(連續的狀態獎勵過程),狀態\(s\)轉移到下一個狀態\(s^′\) 的概率的固定的,與前面的幾輪狀態無關。其中,\(v\)表示一個對當前狀態state 進行估值的函數。\(\gamma\)一般為趨近於1,但是小於1。
3. 代碼實現
代碼主要是根據DDPG的代碼以及TD3的論文復現的,使用的是Pytorch1.7實現的。
3.1 搭建網絡結構
Q1網絡結構主要是用於更新Actor網絡
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super(Actor, self).__init__()
self.f1 = nn.Linear(state_dim, 256)
self.f2 = nn.Linear(256, 128)
self.f3 = nn.Linear(128, action_dim)
self.max_action = max_action
def forward(self,x):
x = self.f1(x)
x = F.relu(x)
x = self.f2(x)
x = F.relu(x)
x = self.f3(x)
return torch.tanh(x) * self.max_action
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super(Critic,self).__init__()
self.f11 = nn.Linear(state_dim+action_dim, 256)
self.f12 = nn.Linear(256, 128)
self.f13 = nn.Linear(128, 1)
self.f21 = nn.Linear(state_dim + action_dim, 256)
self.f22 = nn.Linear(256, 128)
self.f23 = nn.Linear(128, 1)
def forward(self, state, action):
sa = torch.cat([state, action], 1)
x = self.f11(sa)
x = F.relu(x)
x = self.f12(x)
x = F.relu(x)
Q1 = self.f13(x)
x = self.f21(sa)
x = F.relu(x)
x = self.f22(x)
x = F.relu(x)
Q2 = self.f23(x)
return Q1, Q2
3.2 定義網絡
self.actor = Actor(self.state_dim, self.action_dim, self.max_action)
self.target_actor = copy.deepcopy(self.actor)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
#定義critic網絡
self.critic = Critic(self.state_dim, self.action_dim)
self.target_critic = copy.deepcopy(self.critic)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
3.3 更新網絡
更新網絡采用軟更新,延遲更新等方式
def learn(self):
self.total_it += 1
data = self.buffer.smaple(size=128)
state, action, done, state_next, reward = data
with torch.no_grad:
noise = (torch.rand_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
next_action = (self.target_actor(state_next) + noise).clamp(-self.max_action, self.max_action)
target_Q1,target_Q2 = self.target_critic(state_next, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + done * self.discount * target_Q
current_Q1, current_Q2 = self.critic(state, action)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
critic_loss.backward()
self.critic_optimizer.step()
if self.total_it % self.policy_freq == 0:
q1,q2 = self.critic(state, self.actor(state))
actor_loss = -torch.min(q1, q2).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.actor.parameters(), self.target_actor.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
4. 總結
TD3是DDPG的一個升級版,在解決很多的問題上,效果要比DDPG的效果好的多,無論是訓練速度,還是結果都有顯著的提高。
5. 資料
PS:更多技術干貨,快關注【公眾號 | xingzhe_ai】,與行者一起討論吧!