Pytorch中的RNN之pack_padded_sequence()和pad_packed_sequence()


為什么有pad和pack操作?

先看一個例子,這個batch中有5個sample

 

如果不用pack和pad操作會有一個問題,什么問題呢?比如上圖,句子“Yes”只有一個單詞,但是padding了多余的pad符號,這樣會導致LSTM對它的表示通過了非常多無用的字符,這樣得到的句子表示就會有誤差,更直觀的如下圖:

那么我們正確的做法應該是怎么樣呢?

在上面這個例子,我們想要得到的表示僅僅是LSTM過完單詞"Yes"之后的表示,而不是通過了多個無用的“Pad”得到的表示:如下圖:

 

torch.nn.utils.rnn.pack_padded_sequence()

這里的pack,理解成壓緊比較好。 將一個 填充過的變長序列 壓緊。(填充時候,會有冗余,所以壓緊一下)

其中pack的過程為:(注意pack的形式,不是按行壓,而是按列壓)

(下面方框內為PackedSequence對象,由data和batch_sizes組成)

pack之后,原來填充的 PAD(一般初始化為0)占位符被刪掉了。

輸入的形狀可以是(T×B×* )。T是最長序列長度,Bbatch size*代表任意維度(可以是0)。如果batch_first=True的話,那么相應的 input size 就是 (B×T×*)

Variable中保存的序列,應該按序列長度的長短排序,長的在前,短的在后。即input[:,0]代表的是最長的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是維度大於等於2的input都可以作為這個函數的參數。你可以用它來打包labels,然后用RNN的輸出和打包后的labels來計算loss。通過PackedSequence對象的.data屬性可以獲取 Variable

參數說明:

  • input (Variable) – 變長序列 被填充后的 batch
  • lengths (list[int]) – Variable 中 每個序列的長度。
  • batch_first (bool, optional) – 如果是True,input的形狀應該是B*T*size

返回值:

一個PackedSequence 對象。

torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence

上面提到的函數的功能是將一個填充后的變長序列壓緊。 這個操作和pack_padded_sequence()是相反的。把壓緊的序列再填充回來。填充時會初始化為0。

返回的Varaible的值的size是 T×B×*T 是最長序列的長度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*

Batch中的元素將會以它們長度的逆序排列。

參數說明:

  • sequence (PackedSequence) – 將要被填充的 batch
  • batch_first (bool, optional) – 如果為True,返回的數據的格式為 B×T×*

返回值: 一個tuple,包含被填充后的序列,和batch中序列的長度列表

一個例子:

輸出:(這個輸出結果能較為清楚地看到中間過程)

此時PackedSequence對象輸入RNN后,輸出RNN的還是PackedSequence對象

(最后一個unpacked沒有用batch_first,  所以。。。)

 

參考:

https://www.cnblogs.com/lindaxin/p/8052043.html

https://pytorch.org/docs/stable/nn.html?highlight=pack_padded_sequence#torch.nn.utils.rnn.pack_padded_sequence

https://zhuanlan.zhihu.com/p/34418001?edition=yidianzixun&utm_source=yidianzixun&yidian_docid=0IVwLf60


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM