論文筆記之:Deep Attention Recurrent Q-Network


  

Deep Attention Recurrent Q-Network

5vision groups 

 

   摘要:本文將 DQN 引入了 Attention 機制,使得學習更具有方向性和指導性。(前段時間做一個工作打算就這么干,誰想到,這么快就被這幾個孩子給實現了,自愧不如啊( ⊙ o ⊙ ))

    引言:我們知道 DQN 是將連續 4幀的視頻信息輸入到 CNN 當中,那么,這么做雖然取得了不錯的效果,但是,仍然只是能記住這 4 幀的信息,之前的就會遺忘。所以就有研究者提出了 Deep Recurrent Q-Network (DRQN),一個結合 LSTM 和 DQN 的工作:

  1. the fully connected layer in the latter is replaced for a LSTM one , 

  2. only the last visual frame at each time step is used as DQN's input. 

  作者指出雖然只是使用了一幀的信息,但是 DRQN 仍然抓住了幀間的相關信息。盡管如此,仍然沒有看到在 Atari game上有系統的提升。

 

   另一個缺點是:長時間的訓練時間。據說,在單個 GPU 上訓練時間達到 12-14天。於是,有人就提出了並行版本的算法來提升訓練速度。作者認為並行計算並不是唯一的,最有效的方法來解決這個問題。 

  

   最近 visual attention models 在各個任務上都取得了驚人的效果。利用這個機制的優勢在於:僅僅需要選擇然后注意一個較小的圖像區域,可以幫助降低參數的個數,從而幫助加速訓練和測試。對比 DRQN,本文的 LSTM 機制存儲的數據不僅用於下一個 actions 的選擇,也用於 選擇下一個 Attention 區域。此外,除了計算速度上的改進之外,Attention-based models 也可以增加 Deep Q-Learning 的可讀性,提供給研究者一個機會去觀察 agent 的集中區域在哪里以及是什么,(where and what)。

 

 


  

  Deep Attention Recurrent Q-Network:

 

 

    如上圖所示,DARQN 結構主要由 三種類型的網絡構成:convolutional (CNN), attention, and recurrent . 在每一個時間步驟 t,CNN 收到當前游戲狀態 $s_t$ 的一個表示,根據這個狀態產生一組 D feature maps,每一個的維度是 m * m。Attention network 將這些 maps 轉換成一組向量 $v_t = \{ v_t^1, ... , v_t^L \}$,L = m*m,然后輸出其線性組合 $z_t$,稱為 a context vector. 這個 recurrent network,在我們這里是 LSTM,將 context vector 作為輸入,以及 之前的 hidden state $h_{t-1}$,memory state $c_{t-1}$,產生 hidden state $h_t$ 用於:

  1. a linear layer for evaluating Q-value of each action $a_t$ that the agent can take being in state $s_t$ ; 

  2. the attention network for generating a context vector at the next time step t+1. 

 


 

  Soft attention 

  這一小節提到的 "soft" Attention mechanism 假設 the context vector $z_t$ 可以表示為 所有向量 $v_t^i$ 的加權和,每一個對應了從圖像不同區域提取出來的 CNN 特征。權重 和 這個 vector 的重要程度成正比例,並且是通過 Attention network g 衡量的。g network 包含兩個 fc layer 后面是一個 softmax layer。其輸出可以表示為:

  其中,Z是一個normalizing constant。W 是權重矩陣,Linear(x) = Ax + b 是一個放射變換,權重矩陣是A,偏差是 b。我們一旦定義出了每一個位置向量的重要性,我們可以計算出 context vector 為:

  另一個網絡在第三小節進行詳細的介紹。整個 DARQN model 是通過最小化序列損失函數完成訓練:

  其中,$Y_t$ 是一個近似的 target value,為了優化這個損失函數,我們利用標准的 Q-learning 更新規則:

  DARQN 中的 functions 都是可微分的,所以每一個參數都有梯度,整個模型可以 end-to-end 的進行訓練。本文的算法也借鑒了 target network 和 experience replay 的技術。

 


 

  Hard Attention

  此處的 hard attention mechanism 采樣的時候要求僅僅從圖像中采樣一個圖像 patch。

  假設 $s_t$ 從環境中采樣的時候,受到了 attention policy 的影響,attention network g 的softmax layer 給出了帶參數的類別分布(categorical distribution)。然后,在策略梯度方法,策略參數的更新可以表示為:

  其中 $R_t$ 是將來的折扣的損失。為了估計這個值,另一個網絡 $G_t = Linear(h_t)$ 才引入進來。這個網絡通過朝向 期望值 $Y_t$ 進行網絡訓練。Attention network 參數最終的更新采用如下的方式進行:

    其中 $G_t - Y_t$ 是advantage function estimation。

  

  作者提供了源代碼:https://github.com/5vision/DARQN  

  

  實驗部分

  

 

 

 

 


 

  總結:   

 

 

  

 

 

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM