為什么有pad和pack操作?
先看一個例子,這個batch中有5個sample
如果不用pack和pad操作會有一個問題,什么問題呢?比如上圖,句子“Yes”只有一個單詞,但是padding了多余的pad符號,這樣會導致LSTM對它的表示通過了非常多無用的字符,這樣得到的句子表示就會有誤差,更直觀的如下圖:
那么我們正確的做法應該是怎么樣呢?
在上面這個例子,我們想要得到的表示僅僅是LSTM過完單詞"Yes"之后的表示,而不是通過了多個無用的“Pad”得到的表示:如下圖:
torch.nn.utils.rnn.pack_padded_sequence()
這里的pack
,理解成壓緊比較好。 將一個 填充過的變長序列 壓緊。(填充時候,會有冗余,所以壓緊一下)
其中pack的過程為:(注意pack的形式,不是按行壓,而是按列壓)
(下面方框內為PackedSequence
對象,由data和batch_sizes組成)
pack之后,原來填充的 PAD(一般初始化為0)占位符被刪掉了。
輸入的形狀可以是(T×B×* )。T
是最長序列長度,B
是batch size
,*
代表任意維度(可以是0)。如果batch_first=True
的話,那么相應的 input size
就是 (B×T×*)
。
Variable
中保存的序列,應該按序列長度的長短排序,長的在前,短的在后。即input[:,0]
代表的是最長的序列,input[:, B-1]
保存的是最短的序列。
NOTE:
只要是維度大於等於2的input
都可以作為這個函數的參數。你可以用它來打包labels
,然后用RNN
的輸出和打包后的labels
來計算loss
。通過PackedSequence
對象的.data
屬性可以獲取 Variable
。
參數說明:
- input (Variable) – 變長序列 被填充后的 batch
- lengths (list[int]) –
Variable
中 每個序列的長度。 - batch_first (bool, optional) – 如果是
True
,input的形狀應該是B*T*size
。
返回值:
一個PackedSequence
對象。
torch.nn.utils.rnn.pad_packed_sequence()
填充packed_sequence
。
上面提到的函數的功能是將一個填充后的變長序列壓緊。 這個操作和pack_padded_sequence()是相反的。把壓緊的序列再填充回來。填充時會初始化為0。
返回的Varaible的值的size
是 T×B×*
, T
是最長序列的長度,B
是 batch_size,如果 batch_first=True
,那么返回值是B×T×*
。
Batch中的元素將會以它們長度的逆序排列。
參數說明:
- sequence (PackedSequence) – 將要被填充的 batch
- batch_first (bool, optional) – 如果為True,返回的數據的格式為
B×T×*
。
返回值: 一個tuple,包含被填充后的序列,和batch中序列的長度列表
一個例子:
輸出:(這個輸出結果能較為清楚地看到中間過程)
此時PackedSequence對象輸入RNN后,輸出RNN的還是PackedSequence對象
(最后一個unpacked沒有用batch_first, 所以。。。)
參考:
https://www.cnblogs.com/lindaxin/p/8052043.html
https://pytorch.org/docs/stable/nn.html?highlight=pack_padded_sequence#torch.nn.utils.rnn.pack_padded_sequence