小萌新在看pytorch官網 LSTM代碼時 對batch_first 參數 和torch.nn.utils.rnn.pack_padded_sequence 不太理解,
在回去苦學了一番 ,將自己消化過的記錄在這,希望能幫到跟我有同樣迷惑的伙伴
官方API:https://pytorch.org/docs/stable/nn.html?highlight=lstm#torch.nn.LSTM
- 參數
– input_size
– hidden_size
– num_layers
– bias
– batch_first
– dropout
– bidirectional - 特別說下batch_first ,參數默認為False,也就是它鼓勵我們第一維不是batch,這與我們常規輸入想悖,畢竟我們習慣的輸入是(batch, seq_len, hidden_size),那么官方為啥會 這樣子設置呢?

先不考慮hiddem_dim,左邊圖矩陣維度為batch_size * max_length, 6個序列,沒個序列填充到最大長度4,經過轉置后得到max_length *batch_size , 右圖標藍的一列 對應的就是 左圖第二列,而左圖第二列表示的是 每個序列里面第二個token,這樣子有什么好處呢?相當於可以並行處理 每個句子在time step下時刻的計算,這樣就 可以並行過LSTM,從而一定程度上提高處理速度。因為官網放的圖例子 里面數字都是句子token 索引化之后的,反而讓人容易看暈,因而小萌新自己畫了個好理解的圖。一起看下圖呀。
一共有3個句子,最大長度為6,我們之前習慣的是 按行看,我們現在按一列一列來看(就相當於轉置啦)

time step 0接受的是[ZHAOJIAN girls eat];
time step1接收的是[and are apple];
time step2接受的是[YUQIN beautiful PAD];
time step3接收的是[are angles PAD];
以此類推。 現在這3個句子就可以並行過LSTM
pad_sequence
我們知道一個batch里的序列長度是不一致的,而LSTM是無法處理長度不同的序列的,需要pad操作用0把它們都填充成max_length長度。下圖有3個句子,以最長的句子長度 6 作為max_length,其余句子都填充到max_length 。這是PAD的作用,很好理解。
from torch.nn.utils.rnn import pack_padded_sequence ,pad_sequence ,pack_sequence inputs = ["LIHUA went to The TsinghUA University", "Liping went to technical school ", "I work in the mall ", "we both have bright future"] inputs.sort(key=lambda x:len(x.split()),reverse=True) batch_size=len(inputs) max_length=len(inputs[0].split()) lengths=[len(s.split()) for s in inputs] word_to_idx={} for sen in inputs: for word in sen.split(): if word not in word_to_idx: word_to_idx[word]=len(word_to_idx) idx=[] for sentence in inputs: a=[word_to_idx[w] for w in sentence.split()] idx.append(a) pprint(idx) padded_sequence = pad_sequence([torch.FloatTensor(id) for id in idx], batch_first=True) print(padded_sequence) packed_sequence = pack_sequence([torch.FloatTensor(id) for id in idx]) # packed_sequence是PackedSequence的實例
pack_sequence
但帶來一個問題 ,什么問題呢? 對於長度小於MAX_LENGTH ,經過PAD填充操作后的句子,會導致LSTM對它的表示多了很多無用的字符,如下圖所示,我們希望的是在最后一個有用token 就輸入句子的向量表示,而不是在很多PAD后才輸入句子表示,這是pack就派上場了,可以理解成 將一個填充過的變長序列壓緊.壓縮的對象就是 padded suquence, 壓縮后的輸入將不含 0
看圖更好理解 哦

那么,聰明如你肯定會覺得不對勁,這先 填充又 壓緊, 這不是做無用功?其實不是哦,因為 pack 后可並不是一個簡單的 Tensor 類型的數據,而是一個 ”PackedSequence“ 類型的 object,可以直接傳給RNN。小萌新在苦逼地看RNN源碼時,發現forward 函數 里上來就是判斷 輸入是否是 PackedSequence 的實例,進而采取不同的操作。如果輸入是 PackedSequence,輸出也是該類型。這里的輸出類型都指的是 forward 函數的第一個返回值(每個time step 對應的hidden_state),第二個返回值(最后一個time step對應的hidden_state)的類型不管輸入是不是 PackedSequence 類型,都是一樣的。
pack_padded_sequence
pytorch里 有封裝的更好的 :torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
接下來說說這些參數作用
lengths :該參數中各句子長度值的順序要和對應的輸入中的序列順序一致
enforce_sorted: 默認值是 True,表示輸入已經按句子長度降序排好序。如果輸入在 pad 時沒有順序,那么此時在此處需要設置該值為 False,那么函數會再去排序
返回的對象是PackedSequence object。該類型的變量便可以直接喂給 RNN/LSTM等。
torch.nn.utils.rnn.pad_packed_sequence():之前的pack_padded_sequence 是先補齊到相同長度 再壓緊,這個當然就是反過來,對壓緊后的序列 進行擴充補齊操作。
注意:inputs是否排好序和 lengths參數和enforce_sorted 一定要對應起來。小萌新習慣將 inputs 按照長度先排好序,再將length 排好序enforce_sorted參數不去動它。
inputs.sort(key=lambda x:len(x.split()),reverse=True)
lengths=[len(s.split()) for s in inputs]

------------恢復內容結束------------
