This article was first published on: 行者AI
In 2016, Google DeepMind proposed Dueling Network Architectures for Deep Reinforcement Learning. By introducing an advantage function, Dueling DQN can estimate Q values more accurately and choose better actions even when data has only been collected for a single discrete action in a state. Double DQN selects the greedy action with the current Q network but evaluates it with the target Q network, which reduces the over-estimation of Q values. D3QN (Dueling Double DQN) combines the strengths of Dueling DQN and Double DQN.
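For reference, the two bootstrap targets differ only in how the action for the next state is chosen (standard formulation, with \(\theta\) the parameters of the current network \(Q\) and \(\theta'\) those of the target network \(Q'\)):
\[
y^{\text{DQN}} = R + \gamma \max_{a'} Q'(S', a'; \theta'),
\qquad
y^{\text{Double DQN}} = R + \gamma\, Q'\big(S', \mathop{argmax}_{a'} Q(S', a'; \theta); \theta'\big)
\]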
1. Dueling DQN
The network structure of Dueling DQN is shown in Figure 1: the upper network in the figure is the conventional DQN, while the lower one is the Dueling DQN. The difference between the two is that in Dueling DQN the hidden layers split into two streams, one outputting the value function \(V\) and the other the advantage function \(A\); the Q value of each action is then computed as
\[
Q(s,a;\theta,\alpha,\beta) = V(s;\theta,\beta) + \Big(A(s,a;\theta,\alpha) - \frac{1}{|\mathcal{A}|}\sum_{a'} A(s,a';\theta,\alpha)\Big)
\]
Figure 1. Dueling DQN network architecture
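To make the aggregation concrete, here is a tiny sketch with made-up numbers (not taken from the article): subtracting the mean advantage makes the split between \(V\) and \(A\) identifiable, so shifting every advantage by a constant leaves the resulting Q values unchanged.

import torch

# toy value and advantage outputs for one state with three actions (illustrative numbers)
V = torch.tensor([[2.0]])             # shape (1, 1)
A = torch.tensor([[1.0, 0.0, -1.0]])  # shape (1, 3)

Q = V + (A - A.mean(dim=1, keepdim=True))
print(Q)  # tensor([[3., 2., 1.]])

# adding a constant to every advantage does not change Q,
# which is why the mean-subtraction is used in the aggregation
Q_shifted = V + ((A + 5.0) - (A + 5.0).mean(dim=1, keepdim=True))
print(torch.allclose(Q, Q_shifted))  # True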
2. D3QN
Double DQN makes only a small change on top of DQN, so it is not covered in detail here; if you are not yet familiar with DQN, you can check here.
2.1 D3QN Algorithm Flow
-
Initialize the parameters \(\theta\) of the current \(Q\) network and the parameters \(\theta'\) of the target network \(Q'\), and copy the current parameters into the target network, \(\theta \to \theta'\); set the total number of iterations \(T\), the discount factor \(\gamma\), the exploration rate \(\epsilon\), the target-network update frequency \(P\), and the mini-batch size \(m\).
-
Initialize the replay buffer \(D\).
-
for \(t = 1\) to \(T\) do
1) Initialize the environment and obtain the initial state \(S\); set \(R = 0\), \(done = False\)
2) while True
a) Feed the state \(\phi(S)\) into the current \(Q\) network to compute the Q value of every action, and use the \(\epsilon\)-greedy policy to select the action \(A\) for the current state \(S\)
b) Execute action \(A\) and observe the new state \(S'\), the reward \(R\), and the terminal flag \(done\)
c) Store the 5-tuple {\(S, S', A, R, done\)} in \(D\)
d) if \(done\): break
e) Randomly sample \(m\) transitions {\(S_j, S'_j, A_j, R_j, done_j\)}, \(j = 1, 2, \ldots, m\), from \(D\) and compute the target \(y_j\) for the current \(Q\) network: \(y_j = R_j + \gamma\, Q'\big(\phi(S'_j), \mathop{argmax}_{a'} Q(\phi(S'_j), a'; \theta); \theta'\big)\); when \(done_j\) is true, the bootstrap term is dropped and \(y_j = R_j\)
f) Compute the mean-squared loss \(\frac{1}{m}\sum_{j=1}^{m}\big(y_j - Q(\phi(S_j), A_j; \theta)\big)^2\) and update the parameters \(\theta\) by backpropagation (a short code sketch of steps e) and f) follows this listing)
g) if \(t \% P == 0\): \(\theta \to \theta'\), i.e. copy the current network parameters into the target network
h) \(S = S'\)
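Below is a minimal PyTorch sketch of steps e) and f); the function name, tensor names and shapes are assumptions made for illustration, and the full training loop appears in Section 3.4.

import torch
import torch.nn.functional as F

def d3qn_update(network, target_network, optimizer, batch, gamma):
    # batch is assumed to hold five tensors: states (m, state_dim), next_states (m, state_dim),
    # actions (m, 1) long, rewards (m, 1) float, dones (m, 1) float
    states, next_states, actions, rewards, dones = batch

    with torch.no_grad():
        # Double DQN target: the online network selects the next action,
        # the target network evaluates it
        next_actions = network(next_states).argmax(dim=1, keepdim=True)
        next_q = target_network(next_states).gather(1, next_actions)
        y = rewards + gamma * (1 - dones) * next_q

    q = network(states).gather(1, actions)   # Q(S_j, A_j; theta)
    loss = F.mse_loss(q, y)                  # mean over the m samples

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()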
2.2 Tuning D3QN Hyperparameters
-
For the \(\epsilon\)-greedy policy, the choice of the exploration rate epsilon depends heavily on the environment: with many discrete actions a larger epsilon usually works better, and vice versa. While training the Atari game Berzerk-ram-v0, the author raised epsilon from 0.1 to 0.2 and learning efficiency improved substantially (a minimal decay sketch follows this list).
-
As for the network architecture, the author suggests not using an overly wide network, to avoid redundancy and overfitting; layer widths generally stay below \(2^{10}\).
-
The replay buffer capacity is usually set between \(2^{17}\) and \(2^{20}\). As for buffers with prioritized sampling, the author is still experimenting and has not yet obtained satisfactory results on some problems.
-
The batch size is usually a power of 2; the exact value that works best still has to be found by experiment.
-
As for gamma, common choices are 0.99, 0.95 and 0.995. Never set it to 1, since \(\gamma = 1\) brings the risk of Q values growing too large.
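As mentioned in the first bullet, here is a minimal \(\epsilon\)-greedy sketch with linear decay; the starting value, final value and decay horizon are illustrative assumptions, not values taken from the article.

import random

EPSILON_START = 0.2       # assumed starting exploration rate
EPSILON_FINAL = 0.00001   # assumed final exploration rate
DECAY_STEPS = 2_000_000   # assumed number of steps over which epsilon is annealed

def epsilon_greedy(q_values, step, n_action):
    # linearly anneal epsilon from EPSILON_START down to EPSILON_FINAL
    epsilon = max(EPSILON_FINAL,
                  EPSILON_START - (EPSILON_START - EPSILON_FINAL) * step / DECAY_STEPS)
    if random.random() < epsilon:
        return random.randint(0, n_action - 1)   # explore: random action
    return int(q_values.argmax())                # exploit: greedy action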
3. Code Implementation
The author implemented a simple D3QN (Dueling Double DQN); apologies that a Prioritized Replay buffer is not included.
3.1 Network Architecture
The network uses only fully connected layers, no convolutions; action selection is also implemented inside the network.
import random
from itertools import count
from tensorboardX import SummaryWriter
import gym
from collections import deque
import numpy as np
from torch.nn import functional as F
import torch
import torch.nn as nn
class Dueling_DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Dueling_DQN, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        # shared fully connected trunk
        self.f1 = nn.Linear(state_dim, 512)
        self.f2 = nn.Linear(512, 256)
        # separate hidden layers for the value stream and the advantage stream
        self.val_hidden = nn.Linear(256, 128)
        self.adv_hidden = nn.Linear(256, 128)
        self.val = nn.Linear(128, 1)            # V(s)
        self.adv = nn.Linear(128, action_dim)   # A(s, a)

    def forward(self, x):
        x = self.f1(x)
        x = F.relu(x)
        x = self.f2(x)
        x = F.relu(x)
        val_hidden = self.val_hidden(x)
        val_hidden = F.relu(val_hidden)
        adv_hidden = self.adv_hidden(x)
        adv_hidden = F.relu(adv_hidden)
        val = self.val(val_hidden)
        adv = self.adv(adv_hidden)
        # Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')
        adv_ave = torch.mean(adv, dim=1, keepdim=True)
        x = adv + val - adv_ave
        return x

    def select_action(self, state):
        with torch.no_grad():
            Q = self.forward(state)
            action_index = torch.argmax(Q, dim=1)
        return action_index.item()
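A quick sanity check of the network. The state dimension 128 matches the length of an Atari RAM observation used later; the action count 18 is only an illustrative assumption.

net = Dueling_DQN(state_dim=128, action_dim=18)
dummy_state = torch.rand(1, 128)        # one fake RAM observation scaled to [0, 1]
print(net(dummy_state).shape)           # torch.Size([1, 18]), one Q value per action
print(net.select_action(dummy_state))   # index of the greedy action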
3.2 Memory
Used to store experience (transitions).
class Memory(object):
    def __init__(self, memory_size: int):
        self.memory_size = memory_size
        self.buffer = deque(maxlen=self.memory_size)

    def add(self, experience) -> None:
        self.buffer.append(experience)

    def size(self):
        return len(self.buffer)

    def sample(self, batch_size: int, continuous: bool = True):
        if batch_size > self.size():
            batch_size = self.size()
        if continuous:
            # return a contiguous slice of transitions
            rand = random.randint(0, len(self.buffer) - batch_size)
            return [self.buffer[i] for i in range(rand, rand + batch_size)]
        else:
            # return transitions sampled uniformly without replacement
            indexes = np.random.choice(np.arange(len(self.buffer)), size=batch_size, replace=False)
            return [self.buffer[i] for i in indexes]

    def clear(self):
        self.buffer.clear()
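A minimal usage sketch, assuming each experience is stored as a (state, next_state, action, reward, done) tuple as in the main program below (dummy values only):

memory = Memory(memory_size=1000)
memory.add((np.zeros(128), np.zeros(128), 0, 1.0, False))   # one dummy transition
memory.add((np.zeros(128), np.zeros(128), 1, 0.0, True))
batch = memory.sample(2, continuous=False)                   # uniform sampling without replacement
states, next_states, actions, rewards, dones = zip(*batch)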
3.3 Hyperparameters
GAMMA = 0.99                  # discount factor
BATH = 256                    # batch size
EXPLORE = 2000000             # steps over which epsilon would be annealed (see the commented-out decay below)
REPLAY_MEMORY = 50000         # replay buffer capacity
BEGIN_LEARN_SIZE = 1024       # start learning once the buffer holds this many transitions
memory = Memory(REPLAY_MEMORY)
UPDATA_TAGESTEP = 200         # target network update frequency, in learning steps
learn_step = 0
epsilon = 0.2                 # exploration rate
writer = SummaryWriter('logs/dueling_DQN2')
FINAL_EPSILON = 0.00001       # lower bound for the exploration rate
3.4 Main Program
Sets up the optimizer, updates the network parameters, and so on.
env = gym.make('Berzerk-ram-v0')
n_state = env.observation_space.shape[0]
n_action = env.action_space.n

target_network = Dueling_DQN(n_state, n_action)
network = Dueling_DQN(n_state, n_action)
target_network.load_state_dict(network.state_dict())
optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)

r = 0
c = 0
for epoch in count():
    state = env.reset()
    episode_reward = 0
    c += 1
    while True:
        # env.render()
        state = state / 255                                   # scale the RAM observation to [0, 1]
        p = random.random()
        if p < epsilon:
            action = random.randint(0, n_action - 1)          # explore
        else:
            state_tensor = torch.as_tensor(state, dtype=torch.float).unsqueeze(0)
            action = network.select_action(state_tensor)      # exploit
        next_state, reward, done, _ = env.step(action)
        episode_reward += reward
        # store the next state with the same scaling as the current state
        memory.add((state, next_state / 255, action, reward, done))
        if memory.size() > BEGIN_LEARN_SIZE:
            learn_step += 1
            if learn_step % UPDATA_TAGESTEP == 0:             # periodically sync the target network
                target_network.load_state_dict(network.state_dict())
            batch = memory.sample(BATH, False)
            batch_state, batch_next_state, batch_action, batch_reward, batch_done = zip(*batch)
            batch_state = torch.as_tensor(batch_state, dtype=torch.float)
            batch_next_state = torch.as_tensor(batch_next_state, dtype=torch.float)
            batch_action = torch.as_tensor(batch_action, dtype=torch.long).unsqueeze(1)
            batch_reward = torch.as_tensor(batch_reward, dtype=torch.float).unsqueeze(1)
            batch_done = torch.as_tensor(batch_done, dtype=torch.float).unsqueeze(1)
            with torch.no_grad():
                # Double DQN target: the online network picks the action,
                # the target network evaluates it
                target_Q_next = target_network(batch_next_state)
                Q_next = network(batch_next_state)
                Q_max_action = torch.argmax(Q_next, dim=1, keepdim=True)
                y = batch_reward + GAMMA * (1 - batch_done) * target_Q_next.gather(1, Q_max_action)
            loss = F.mse_loss(network(batch_state).gather(1, batch_action), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer.add_scalar('loss', loss.item(), global_step=learn_step)
            # if epsilon > FINAL_EPSILON:  ## anneal exploration
            #     epsilon -= (0.1 - FINAL_EPSILON) / EXPLORE
        if done:
            break
        state = next_state
    r += episode_reward
    writer.add_scalar('episode reward', episode_reward, global_step=epoch)
    if epoch % 100 == 0:
        print(f"block {epoch // 100}: average reward over the last 100 epochs = {r / 100}", epsilon)
        r = 0
    if epoch % 10 == 0:
        torch.save(network.state_dict(), 'model/network{}.pt'.format("dueling"))
4. Resources
PS: For more technical content, follow the WeChat official account [xingzhe_ai] and join the discussion with 行者!