強化學習原理源碼解讀002:DQN


目錄

  Policy based方法 vs Value based方法

  策略網絡

  算法總體流程

  如何通過對回歸任務的優化來更新Q網絡

  為什么不可以同時更新Q網絡和目標網絡

  為什么要使用帶有探索策略的Q函數

  探索策略的數學表達

  ReplayBuffer的作用

  Q值被高估的問題

  源碼實現

  參考資料


DQN是Deep Q Network的縮寫,由Google Deep mind 團隊提出。

Policy based方法 vs Value based方法

 

上一篇文章中提到的Policy Gradient屬於Policy based的RL學習方法。

本文介紹的DQN屬於Value based的RL學習方法。

兩者區別:

Policy based是直接對累計獎勵值進行最大化求解,在實做過程中,在很多任務中是訓練不出比較好的智能體的;

而Value based方法是不直接對累計獎勵值進行最大化求解,而是設置一個價值函數(狀態或動作)來評價當前智能體到最后獲得獎勵值的期望,通過這種評價,再建立優化方案,從而達到對總體較優累計獎勵值的求解。狀態價值函數(State value)記為,動作價值函數(State-action value)記為

 

 返回目錄

 

策略網絡

self.fc1 = nn.Linear(4, 128)

self.fc2 = nn.Linear(128, 128)

self.fc3 = nn.Linear(128, 2)

 返回目錄

 

算法總體流程

 

我們針對其中的幾個要點進行展開:

■如何通過對回歸任務的優化來更新Q網絡

■為什么不可以同時更新Q網絡和目標網絡

■為什么要使用帶有探索策略的Q函數

■探索策略的數學表達

■ReplayBuffer的作用

■Q值被高估的問題

 返回目錄

 

如何通過對回歸任務的優化來更新Q網絡

假設我們收集到的某一筆數據為

原始Q網絡計算在狀態下執行動作,產生輸出

目標Q網絡計算在狀態下執行動作,產生輸出

那么,就根據構建適用於回歸的損失函數,更新時只更新原始Q網絡,一段時間之后使用原始Q網絡的參數覆蓋目標Q網絡 。 

 返回目錄

 

為什么不可以同時更新Q網絡和目標網絡

實驗表明,同時更新兩個網絡會出現學習不穩定的情況。

 返回目錄

 

為什么要使用帶有探索策略的Q函數

當我們使用Q函數的時候,我們的π完全依賴於Q函數,窮舉每一個a,看哪一個可以讓Q最大。

這和policy Gradient不一樣,在做PG的時候,我們輸出是隨機的,我們輸出一個動作的分布,然后采樣一個動作,所以在PG里每一次采取的動作是有隨機性的。

很顯然,剛開始估出來的Q函數是可靠的,假設有一個動作得到過獎勵,那未來會一直采樣這個動作。

例子1:你去了一個餐廳,點了一盤椒麻雞,感覺好吃,以后去這個餐廳就一直點椒麻雞,就不去探索是不是有更好吃的東西了。

例子2:玩貪吃蛇時,某一次向上走吃到了一個星星,那他以后就一直認為向上走是最好的,以至於很快就撞牆死掉。

 返回目錄

 

探索策略的數學表達

列舉兩種方式:

方式一:Epsilon-Greedy

 

ε會隨着時間的推移,逐漸變小。因為剛開始的時候需要更多的探索,當Q學習得比較不錯的時候,就可以減少探索的概率。

方式二:Boltzmann Exploration

 

剛開始是一個均勻分布,后來價值高的動作采樣到的概率越來越高。

其實還有比較高級的Noisy Net的方式

 返回目錄

 

ReplayBuffer的作用

現在有一個智能體π和環境做互動來收集數據,我們會把所有的數據放在一個buffer里面,假設里面可以存5w個資料,每一筆資料就是一個四元組

這里面的數據,可能來自於不同的策略。這個buffer只有在裝滿之后才會把舊的資料丟棄。

更新Q函數時,就從buffer中隨機抽一個batch,然后去訓練更新。

現在其實就變成了off-policy的,因為我們的Q本來要觀察π的價值的,但是存在buffer里的經驗,不是統統來自於π,有一些是過去的π遺留下來的。

好處:

1.在做強化學習的時候,往往耗時的是在於和環境做互動,訓練的過程往往速度比較快,用了buffer可以減少和環境做互動的次數,因為在做訓練的時候,經驗不需要統統來自某一個π,一些過去的經驗也可以放在buffer里被使用很多次。

2.在訓練網絡的時候,我們希望一個batch里面的數據越不同越好,如果batch里的數據都是同樣性質的,訓練下去是容易壞掉的。

問題:我們明明是要觀察π的價值,里面混雜了一些不是π的經驗,到底有沒有關系?

一個簡單的解釋:這些π差的並不多,太老會自動舍棄的,所以沒有關系。

 返回目錄

 

Q值被高估的問題

在算target的時候,我們實際上在做的事情是,看哪一個a可以得到最大的Q值,就把他加上去作為target,假設有某一個動作他得到的值是被高估的,所以很大概率會選到那些值被高估的動作的值當做max的結果,再加上rt當做target,所以target總會太大。

解決方法:

最簡單的方式就是對target的乘以一個小數

復雜的做法:

Double DQN

 返回目錄

 

源碼實現

 代碼

import gym
import collections
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

#Hyperparameters
learning_rate = 0.0005
gamma         = 0.98
buffer_limit  = 50000
batch_size    = 32

class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done_mask = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
               torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \
               torch.tensor(done_mask_lst)

    def size(self):
        return len(self.buffer)

class Qnet(nn.Module):
    def __init__(self):
        super(Qnet, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def sample_action(self, obs, epsilon):
        out = self.forward(obs)
        coin = random.random()
        if coin < epsilon:
            return random.randint(0,1)
        else :
            return out.argmax().item()

def train(q, q_target, memory, optimizer):
    for i in range(10):
        s,a,r,s_prime,done_mask = memory.sample(batch_size)

        q_out = q(s)
        q_a = q_out.gather(1,a)
        max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def main():
    env = gym.make('CartPole-v1')
    q = Qnet()
    q_target = Qnet()
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()
    x = []
    y = []

    print_interval = 20
    score = 0.0
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)

    for n_epi in range(5000):
        epsilon = max(0.01, 0.08 - 0.01*(n_epi/200)) #Linear annealing from 8% to 1%
        s = env.reset()
        done = False

        while not done:
            a = q.sample_action(torch.from_numpy(s).float(), epsilon)
            s_prime, r, done, info = env.step(a)
            done_mask = 0.0 if done else 1.0
            memory.put((s,a,r/100.0,s_prime, done_mask))
            s = s_prime

            score += r
            if done:
                break

        if memory.size()>2000 and score<500*print_interval:
            train(q, q_target, memory, optimizer)

        if n_epi%print_interval==0 and n_epi!=0:
            q_target.load_state_dict(q.state_dict())
            x.append(n_epi)
            y.append(score / print_interval)
            print("n_episode :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format(
                                                            n_epi, score/print_interval, memory.size(), epsilon*100))
            score = 0.0
    env.close()

    env.close()

    plt.plot(x, y)
    plt.savefig('pic_saved/res_dqn.jpg')
    plt.show()


if __name__ == '__main__':
    main()
View Code

 

效果如下圖所示,橫坐標表示訓練輪數,縱坐標表示智能體平均得分,游戲滿分500分

 返回目錄

 

參考資料

https://github.com/seungeunrho/minimalRL

https://www.bilibili.com/video/BV1UE411G78S?from=search&seid=10996250814942853843

 

 返回目錄


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM