在上一篇文章中介紹了MDP與Bellman方程,MDP可以對強化學習的問題進行建模,Bellman提供了計算價值函數的迭代公式。但在實際問題中,我們往往無法准確獲知MDP過程中的轉移概率$P$,因此無法直接將解決 MDP 問題的經典思路 value iteration 和 policy iteration 應用到解決強化學習的問題上。為了將轉移概率以逼近實際情況的方式計算出來,基於value iteration的Q-Learning算法應運而生,它通過在迭代過程中不斷更新Q-table的方式來近似轉移概率矩陣$P$。此外,Sarsa還可以online的形式學習,區別在於與Q-Learning的迭代過程不同。最后,本文還將介紹DQN (Deep Q-Learning Network),當轉移矩陣P (Q-table) 過大時計算困難甚至無法計算,而DQN其利用Deep網絡結構擬合Q-table,使得Q-Learning框架具備了解決狀態無限(動作仍舊有限)的強化學習問題。
1. Q-Learning
Q-Learning 是一個強化學習中一個很經典的算法,其出發點很簡單,就是用一張表存儲在各個狀態下執行各種動作能夠帶來的 reward,如下表表示了有兩個狀態 $s1$,$s2$,每個狀態下有兩個動作 $a1$,$a2$, 表格里面的值表示 reward
reward | action 1 | action 2 |
state 1 | -1 | 2 |
state 2 | -5 | 2 |
這個表就是 Q-Table,里面的每個值定義為$ Q(s,a) $, 表示在狀態$ s $下執行動作$ a $ 所獲取的reward,那么選擇的時候可以采用一個貪婪的做法,即選擇價值最大的那個動作去執行。
Q-Table通過隨機初始化來生成初始表格,然后通過不斷執行動作獲取環境的反饋並通過算法更新 Q-Table。下面重點講如何通過算法更新 Q-Table。
當我們處於某個狀態$ s $時,根據 Q-Table 的值選擇的動作$ a $, 那么從表格獲取的 reward 為 $ Q(s,a) $,此時的 reward 並不是我們真正的獲取的 reward,而是預期獲取的 reward,那么真正的 reward 在哪?我們知道執行了動作$ a $並轉移到了下一個狀態$ s′ $時,能夠獲取一個即時的 reward(記為$ r $), 但是除了即時的 reward,還要考慮所轉移到的狀態 $ s′ $ 對未來期望的reward,因此真實的 reward (記為$ Q′(s,a) $)由兩部分組成:即時的 reward 和未來期望的 reward,且未來的 reward 往往是不確定的,因此需要加個折扣因子$ \gamma $,則真實的 reward 表示如下:
$$ Q’(s,a) = r + \gamma\max_{a’}Q(s’,a’) $$
$ \gamma $的值一般設置為 0 到 1 之間,設為0時表示只關心即時回報,設為 1 時表示未來的期望回報跟即時回報一樣重要。
有了真實的 reward 和預期獲取的 reward,可以很自然地想到用 supervised learning那一套,求兩者的誤差然后進行更新,在 Q-learning 中也是這么干的,更新的值則是原來的$ Q(s, a) $,更新規則如下:
$$ Q(s, a) = Q(s, a) + \alpha(Q’(s, a) - Q(s,a)) $$
更新規則跟梯度下降非常相似,這里的$ \alpha $可理解為學習率。更新規則也很簡單,可是這一類采用了貪心思想的算法往往都會有這么一個問題:算法是否能夠收斂,是收斂到局部最優還是全局最優?
關於收斂性,可以參考 Convergence of Q-learning: a simple proof,這個文檔 證明了這個算法能夠收斂,且根據知乎上這個問題 RL兩大類算法的本質區別?(Policy Gradient 和 Q-Learning),原始的 Q-Learning 理論上能夠收斂到最優解,但是通過 Q 函數近似 Q-Table 的方法則未必能夠收斂到最優解(如DQN)。
除此之外, Q-Learning 中還存在着探索與利用(Exploration and Exploition)的問題, 大致的意思就是不要每次都遵循着當前看起來是最好的方案,而是會選擇一些當前看起來不是最優的策略,這樣也許會更快探索出更優的策略。Exploration and Exploition 的做法很多,Q-Learning 采用了最簡單的$ \epsilon-greedy $, 就是每次有$ \epsilon $的概率是選擇當前 Q-Table 里面值最大的action的,$ 1 - \epsilon $的概率是隨機選擇策略的。
Q-Learning 算法的流程如下,圖片摘自這里:
上面的流程中的 Q 現實 就是上面說的 $ Q′(s,a) $, Q 估計就是上面說的$ Q(s,a) $。
下面的 python 代碼演示了更新通過 Q-Table 的算法, 參考了這個 repo 上的代碼,初始化主要是設定一些參數,並建立 Q-Table, choose_action 是根據當前的狀態 observation,並以 $ \epsilon-greedy $ 的策略選擇當前的動作; learn 則是更新當前的 Q-Table,check_state_exist 則是檢查當前的狀態是否已經存在 Q-Table 中,若不存在要在 Q-Table 中創建相應的行。

1 import numpy as np 2 import pandas as pd 3 4 class QTable: 5 def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9): 6 self.actions = actions # a list 7 self.lr = learning_rate 8 self.gamma = reward_decay 9 self.epsilon = e_greedy 10 self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64) 11 12 def choose_action(self, observation): 13 self.check_state_exist(observation) 14 # action selection 15 if np.random.uniform() < self.epsilon: 16 # choose best action 17 state_action = self.q_table.ix[observation, :] 18 state_action = state_action.reindex(np.random.permutation(state_action.index)) # some actions have same value 19 action = state_action.argmax() 20 else: 21 # choose random action 22 action = np.random.choice(self.actions) 23 return action 24 25 def learn(self, s, a, r, s_): 26 self.check_state_exist(s_) 27 q_predict = self.q_table.ix[s, a] 28 if s_ != 'terminal': 29 q_target = r + self.gamma * self.q_table.ix[s_, :].max() # next state is not terminal 30 else: 31 q_target = r # next state is terminal 32 self.q_table.ix[s, a] += self.lr * (q_target - q_predict) # update 33 34 def check_state_exist(self, state): 35 if state not in self.q_table.index: 36 # append new state to q table 37 self.q_table = self.q_table.append( 38 pd.Series( 39 [0]*len(self.actions), 40 index=self.q_table.columns, 41 name=state, 42 ) 43 )
2. Sarsa
Sarsa 跟 Q-Learning 非常相似,也是基於 Q-Table 進行決策的。不同點在於決定下一狀態所執行的動作的策略,Q-Learning 在當前狀態更新 Q-Table 時會用到下一狀態Q值最大的那個動作,但是下一狀態未必就會選擇那個動作;但是 Sarsa 會在當前狀態先決定下一狀態要執行的動作,並且用下一狀態要執行的動作的 Q 值來更新當前狀態的 Q 值;說的好像很繞,但是看一下下面的流程便可知道這兩者的具體差異了,圖片摘自這里
那么,這兩者的區別在哪里呢?這篇文章里面是這樣講的
This means that SARSA takes into account the control policy by which the agent is moving, and incorporates that into its update of action values, where Q-learning simply assumes that an optimal policy is being followed.
簡單來說就是 Sarsa 在執行action時會考慮到全局(如更新當前的 Q 值時會先確定下一步要走的動作), 而 Q-Learning 則顯得更加的貪婪和”短視”, 每次都會選擇當前利益最大的動作(不考慮 $ \epsilon-greedy $),而不考慮其他狀態。
那么該如何選擇,根據這個問題:When to choose SARSA vs. Q Learning,有如下結論
If your goal is to train an optimal agent in simulation, or in a low-cost and fast-iterating environment, then Q-learning is a good choice, due to the first point (learning optimal policy directly). If your agent learns online, and you care about rewards gained whilst learning, then SARSA may be a better choice.
簡單來說就是如果要在線學習,同時兼顧 reward 和總體的策略(如不能太激進,agent 不能很快掛掉),那么選擇 Sarsa;而如果沒有在線的需求的話,可以通過 Q-Learning 線下模擬找到最好的 agent。所以也稱 Sarsa 為on-policy,Q-Learning 為 off-policy。
3. DQN
我們前面提到的兩種方法都以依賴於 Q-Table,但是其中存在的一個問題就是當 Q-Table 中的狀態比較多,可能會導致整個 Q-Table 無法裝下內存。因此,DQN 被提了出來,DQN 全稱是 Deep Q Network,Deep 指的是通的是深度學習,其實就是通過神經網絡來擬合整張 Q-Table。
DQN 能夠解決狀態無限,動作有限的問題;具體來說就是將當前狀態作為輸入,輸出的是各個動作的 Q 值。以 Flappy Bird 這個游戲為例,輸入的狀態近乎是無限的(當前 bird 的位置和周圍的水管的分布位置等),但是輸出的動作只有兩個(飛或者不飛)。實際上,已經有人通過 DQN 來玩這個游戲了,具體可參考這個 DeepLearningFlappyBird
所以在 DQN 中的核心問題在於如何訓練整個神經網絡,其實訓練算法跟 Q-Learning 的訓練算法非常相似,需要利用 Q 估計和 Q 現實的差值,然后進行反向傳播。
這里放上提出 DQN 的原始論文 Playing atari with deep reinforcement learning 中的算法流程圖
上面的算法跟 Q-Learning 最大的不同就是多了 Experience Replay 這個部分,實際上這個機制做的事情就是先進行反復的實驗,並將這些實驗步驟獲取的 sample 存儲在 memory 中,每一步就是一個 sample,每個sample是一個四元組,包括:當前的狀態,當前狀態的各種action的 Q 值,當前采取的action獲得的即時回報,下一個狀態的各種action的Q值。拿到這樣一個 sample 后,就可以根據上面提到的 Q-Learning 更新算法來更新網絡,只是這時候需要進行的是反向傳播。
Experience Replay 機制的出發點是按照時間順序所構造的樣本之間是有關的(如上面的$ \phi(s_{t+1}) $ 會受到$ \phi(s_{t}) $的影響)、非靜態的(highly correlated and non-stationary),這樣會很容易導致訓練的結果難以收斂。通過 Experience Replay 機制對存儲下來的樣本進行隨機采樣,在一定程度上能夠去除這種相關性,進而更容易收斂。當然,這種方法也有弊端,就是訓練的時候是 offline 的形式,無法做到 online 的形式。
除此之外,上面算法流程圖中的 aciton-value function 就是一個深度神經網絡,因為神經網絡是被證明有萬有逼近的能力的,也就是能夠擬合任意一個函數;一個 episode 相當於 一個 epoch;同時也采用了$ \epsilon-greedy $策略。代碼實現可參考上面 FlappyBird 的 DQN 實現。
上面提到的 DQN 是最原始的的網絡,后面Deepmind 對其進行了多種改進,比如說 Nature DQN 增加了一種新機制 separate Target Network,就是計算上圖的$ y_j $ 的時候不采用網絡 $ Q $, 而是采用另外一個網絡(也就是 Target Network) $ Q′ $, 原因是上面計算$ y_j $和 Q 估計都采用相同的網絡$ Q $,這樣使得$ Q $大的樣本,$ y $也會大,這樣模型震盪和發散可能性變大,其原因其實還是兩者的關聯性較大。而采用另外一個獨立的網絡使得訓練震盪發散可能性降低,更加穩定。一般$ Q′ $會直接采用舊的$ Q $, 比如說 10 個 epoch 前的$ Q $.
除此之外,大幅度提升 DQN 玩 Atari 性能的主要就是 Double DQN,Prioritised Replay 還有 Dueling Network 三大方法;這里不詳細展開,有興趣可參考這兩篇文章:DQN從入門到放棄6 DQN的各種改進 和 深度強化學習(Deep Reinforcement Learning)入門:RL base & DQN-DDPG-A3C introduction。
綜上,本文介紹了強化學習中基於 value 的方法:包括 Q-Learning 以及跟 Q-Learning 非常相似的 Sarsa,同時介紹了通過 DQN 解決狀態無限導致 Q-Table過大的問題。需要注意的是 DQN 只能解決動作有限的問題,對於動作無限或者說動作取值為連續值的情況,需要依賴於 policy gradient 這一類算法,而這一類算法也是目前更為推崇的算法,在下一章將介紹 Policy Gradient 以及結合 Policy Gradient 和 Q-Learning 的 Actor-Critic 方法。
參考: