pytorch中的pack_padded_sequence和pad_packed_sequence用法


pack_padded_sequence是將句子按照batch優先的原則記錄每個句子的詞,變化為不定長tensor,方便計算損失函數。

pad_packed_sequence是將pack_padded_sequence生成的結構轉化為原先的結構,定長的tensor。

其中test.txt的內容

As they sat in a nice coffee shop, 
he was too nervous to say anything and she felt uncomfortable. 
Suddenly, he asked the waiter, 
"Could you please give me some salt? I'd like to put it in my coffee."

具體參見如下代碼

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import wordfreq

vocab = {}
token_id = 1
lengths = []

#讀取文件,生成詞典
with open('test.txt', 'r') as f:
    lines=f.readlines()
    for line in lines:
        tokens = wordfreq.tokenize(line.strip(), 'en')
        lengths.append(len(tokens))
        #將每個詞加入到vocab中,並同時保存對應的index
        for word in tokens:
            if word not in vocab:
                vocab[word] = token_id
                token_id += 1

x = np.zeros((len(lengths), max(lengths)))
l_no = 0
#將詞轉化為數字
with open('test.txt', 'r') as f:
    lines = f.readlines()
    for line in lines:
        tokens = wordfreq.tokenize(line.strip(), 'en')
        for i in range(len(tokens)):
            x[l_no, i] = vocab[tokens[i]]
        l_no += 1

x=torch.Tensor(x)
x = Variable(x)
print(x)
'''
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
        [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.]])
'''
lengths = torch.Tensor(lengths)
print(lengths)#tensor([ 8., 11.,  5., 14.])

_, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True)
print(_) #tensor([14., 11.,  8.,  5.])
print(idx_sort)#tensor([3, 1, 0, 2])

lengths = list(lengths[idx_sort])#按下標取元素 [tensor(14.), tensor(11.), tensor(8.), tensor(5.)]
t = x.index_select(0, idx_sort)#按下標取元素
print(t)
'''
tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.],
        [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
        [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
        [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])
'''
x_packed = nn.utils.rnn.pack_padded_sequence(input=t, lengths=lengths, batch_first=True)
print(x_packed)
'''
PackedSequence(data=tensor([24.,  9.,  1., 20., 25., 10.,  2.,  9., 26., 11.,  3., 21., 27., 12.,
         4., 22., 28., 13.,  5., 23., 29., 14.,  6., 30., 15.,  7., 31., 16.,
         8., 32., 17., 13., 18., 33., 19., 34.,  4.,  7.]), batch_sizes=tensor([4, 4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1]))
'''


x_padded = nn.utils.rnn.pad_packed_sequence(x_packed, batch_first=True)#x_padded是tuple
print(x_padded)
'''
(tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.],
        [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
        [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
        [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]), tensor([14, 11,  8,  5]))
'''
#還原tensor
_, idx_unsort = torch.sort(idx_sort)
output = x_padded[0].index_select(0, idx_unsort)
print(output)
'''
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
        [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.]])
'''

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM