上篇文章強化學習——狀態價值函數逼近介紹了價值函數逼近(Value Function Approximation,VFA)的理論,本篇文章介紹大名鼎鼎的DQN算法。DQN算法是 DeepMind 團隊在2015年提出的算法,對於強化學習訓練苦難問題,其開創性的提出了兩個解決辦法,在atari游戲上都有不俗的表現。論文發表在了 Nature 上,此后的一些DQN相關算法都是在其基礎上改進,可以說是打開了深度強化學習的大門,意義重大。
論文地址:Mnih, Volodymyr; et al. (2015). Human-level control through deep reinforcement learning
一、DQN簡介
其實DQN就是 Q-Learning 算法 + 神經網絡。我們知道,Q-Learning 算法需要維護一張 Q 表格,按照下面公式來更新:
然后學習的過程就是更新 這張 Q表格,如下圖所示:

而DQN就是用神經網絡來代替這張 Q 表格,其余相同,如下圖:

其更新方式為:
其中 \(\Delta w\) :
二、Experience replay
DQN 第一個特色是使用 Experience replay ,也就是經驗回放,為何要用經驗回放?還請看下文慢慢詳述
對於網絡輸入,DQN 算法是把整個游戲的像素作為 神經網絡的輸入,具體網絡結構如下圖所示:
第一個問題就是樣本相關度的問題,因為在強化學習過程中搜集的數據就是一個時序的玩游戲序列,游戲在像素級別其關聯度是非常高的,可能只是在某一處特別小的區域像素有變化,其余像素都沒有變化,所以不同時序之間的樣本的關聯度是非常高的,這樣就會使得網絡學習比較困難。
DQN的解決辦法就是 經驗回放(Experience replay)。具體來說就是用一塊內存空間 \(D\) ,用來存儲每次探索獲得數據 \(<s_t, a_t, r_t, s_{t+1}>\) 。然后按照以下步驟重復進行:
- sample:從 \(D\) 中取出一個 batch 的數據 \((s, a, r, s') \in D\)
- 對於取出的樣本數據計算 Target 值:\(r + \gamma\; max_{a'}\hat{Q}(s',a',w)\)
- 使用隨機梯度下降來更新網絡權重 w:
利用經驗回放,可以充分發揮 off-policy 的優勢,behavior policy 用來搜集經驗數據,而 target policy 只專注於價值最大化。
三、Fixed Q targets
第二個問題是前文已經說過的,我們使用 \(\hat{q}(s_t, a_t, w)\) 來代替 TD Target,也就是說在TD Target 中已經包含我了我們要優化的 參數 w。也就是說在訓練的時候 監督數據 target 是不固定的,所以就使得訓練比較困難。
想象一下,我們把 我們要估計的 \(\hat{Q}\) 叫做 Q estimation,把它看做一只貓。把我們的監督數據 Q Target 看做是一只老鼠,現在可以把訓練的過程看做貓捉老鼠的過程(不斷減少之間的距離,類比於我們的 Q estimation 網絡擬合 Q Target 的過程)。現在問題是貓和老鼠都在移動,這樣貓想要捉住老鼠是比較困難的,如下所示:

那么我們讓老鼠在一段時間間隔內不動(固定住),而這期間,貓是可以動的,這樣就比較容易抓住老鼠了。在 DQN 中也是這樣解決的,我們有兩套一樣的網絡,分別是 Q estimation 網絡和 Q Target 網絡。要做的就是固定住 Q target 網絡,那如何固定呢?比如我們可以讓 Q estimation 網路訓練10次,然后把 Q estimation 網絡更新后的參數 w 賦給 Q target 網絡。然后我們再讓Q estimation 網路訓練10次,如此往復下去,試想如果不固定 Q Target 網絡,兩個網絡都在不停地變化,這樣 擬合是很困難的,如果我們讓 Q Target 網絡參數一段時間固定不變,那么擬合過程就會容易很多。下面是 DQN 算法流程圖:
如上圖所示,首先智能體不斷與環境交互,獲取交互數據\(<s,a,r,s'>\)存入replay memory
,當經驗池中有足夠多的數據后,從經驗池中 隨機取出一個 batch_size
大小的數據,利用當前網絡計算 Q的預測值,使用 Q-Target 網絡計算出 Q目標值,然后計算兩者之間的損失函數,利用梯度下降來更新當前 網絡參數,重復若干次后,把當前網絡的參數 復制給 Q-Target 網絡。
關於DQN的實現代碼部分我們下篇介紹
參考資料:
- B站 周老師的 強化學習綱要第四課下