tf實現LSTM時rnn.DropoutWrapper


轉自:https://blog.csdn.net/abclhq2005/article/details/78683656 作者:abclhq2005 

1.概念介紹

所謂dropout,就是指網絡中每個單元在每次有數據流入時以一定的概率(keep prob)正常工作,否則輸出0值。這是是一種有效的正則化方法,可以有效防止過擬合。

在rnn中進行dropout時,對於rnn的部分不進行dropout,也就是說從t-1時候的狀態傳遞到t時刻進行計算時,這個中間不進行memory的dropout;僅在同一個t時刻中,多層cell之間傳遞信息的時候進行dropout,如下圖所示

 上圖中,t-2時刻的輸入xt−2首先傳入第一層cell,這個過程有dropout,但是從t−2時刻的第一層cell傳到t−1,t,t+1的第一層cell這個中間都不進行dropout。再從t+1時候的第一層cell向同一時刻內后續的cell傳遞時,這之間又有dropout。

 2.用法

在使用tf.nn.rnn_cell.DropoutWrapper時,同樣有一些參數,例如input_keep_prob,output_keep_prob等,分別控制輸入和輸出的dropout概率,很好理解。
可以從官方文檔中看到,它有input_keep_prob和output_keep_prob,也就是說裹上這個DropoutWrapper之后,如果我希望是input傳入這個cell時dropout掉一部分input信息的話,就設置input_keep_prob,那么傳入到cell的就是部分input;如果我希望這個cell的output只部分作為下一層cell的input的話,就定義output_keep_prob。
備注:Dropout只能是層與層之間(輸入層與LSTM1層、LSTM1層與LSTM2層)的Dropout;同一個層里面,T時刻與T+1時刻是不會Dropout的。

3.參數

 

__init__(
    cell,
    input_keep_prob=1.0,
    output_keep_prob=1.0,
    state_keep_prob=1.0,
    variational_recurrent=False,
    input_size=None,
    dtype=None,
    seed=None,
    dropout_state_filter_visitor=None
)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM