Pytorch 解決lstm輸入可變長序列問題


我是做圖像到語音音素識別相關的科研工作的,需要用到lstm識別一個序列幀所對應的音素,但是不同音素有不同長度的幀,所以需要解決變長序列的問題。

需要解決這個問題的原因是:不等長的序列無法進行batch

我主要參考的文章:https://blog.csdn.net/kejizuiqianfang/article/details/100835528

但我又不能完全套用這篇文章的方法。這篇文章舉例用的是無標簽的數據集,但我們很可能處理有便簽的數據集,所以這篇文章主要說明如何處理有標簽的數據集

 

主要使用的也是這三個方法:

torch.nn.utils.rnn.pad_sequence()  把一個batch列表(注意必須是列表,不能是元組)中不等長的tensor補充成等長的tensor后返回batch的tensor

torch.nn.utils.rnn.pack_padded_sequence()  把pad后的tensor壓縮成各個序列的實際長度,同時數據變成PackedSequence類型

torch.nn.utils.rnn.pad_packed_sequence()  把上面所壓縮成PackedSequence的數據還原成tensor類型, 並補成等長的數據

注意不要忘記 import torch.nn.utils.rnn as rnn_utils

 

這次我舉一個有標簽數據的例子

train_x = [(torch.tensor([1, 2, 3, 4, 5, 6, 7]), torch.tensor(1)),
           (torch.tensor([2, 3, 4, 5, 6, 7]), torch.tensor(1)),
           (torch.tensor([3, 4, 5, 6, 7]), torch.tensor(1)),
            (torch.tensor([4, 5, 6, 7]), torch.tensor(1)),
             (torch.tensor([5, 6, 7]), torch.tensor(1)),
              (torch.tensor([6, 7]), torch.tensor(1)),
               (torch.tensor([7]), torch.tensor(1))]

這里的train_x是已經經過class myDataset(Data.Dataset)返回的數據格式。我們需要注意的是train_x這個列表中是一個一個元組,元組的第一個元素是序列tensor,第二個元素是標簽tensor。
物品參考的文章中與這不同的是列表里存的數一個一個序列tensor。
如果按照上文那樣pad會報一個錯誤:

 

 所以,在處理collate_fn函數時需要注意把元組中的序列和標簽分別取出來組成list。注意在排序的時候要對train_data排序,主要要和標簽一起排

def collate_fn(train_data):
    # print("排序前:", train_data)
    train_data.sort(key=lambda data: len(data[0]), reverse=True)
    # print("排序后:", train_data)
    train_x = []
    train_y = []
    for data in train_data:
        train_x.append(data[0])
        train_y.append(data[1])

    data_length = [len(data) for data in train_x]
    print(data_length)
    train_x = rnn_utils.pad_sequence(train_x, batch_first=True, padding_value=0)
    train_y = torch.from_numpy(np.asarray(train_y))
    return train_x.unsqueeze(-1), train_y, data_length 

 

在輸入到lstm之前,需要進行壓縮,提高效率

data = rnn_utils.pack_padded_sequence(data_x, length, batch_first=True)
    output= net(data.float())

 

如果有全連接層,需要在進入全連接層之前,進行pad_packed_sequence

下面是整體的代碼:

import torch.nn.utils.rnn as rnn_utils
import torch
import torch.utils.data as Data
import torch.nn as nn
import numpy as np

train_x = [(torch.tensor([1, 2, 3, 4, 5, 6, 7]), torch.tensor(1)),
           (torch.tensor([2, 3, 4, 5, 6, 7]), torch.tensor(1)),
           (torch.tensor([3, 4, 5, 6, 7]), torch.tensor(1)),
            (torch.tensor([4, 5, 6, 7]), torch.tensor(1)),
             (torch.tensor([5, 6, 7]), torch.tensor(1)),
              (torch.tensor([6, 7]), torch.tensor(1)),
               (torch.tensor([7]), torch.tensor(1))]

def collate_fn(train_data):
    train_x.sort(key=lambda data: len(data[0]), reverse=True)
   train_x = []
    train_y = []
    for data in train_data:
        train_x.append(data[0])
        train_y.append(data[1])

    data_length = [len(data) for data in train_x]
    # print(data_length)
    train_x = rnn_utils.pad_sequence(train_x, batch_first=True, padding_value=0)
    train_y = torch.from_numpy(np.asarray(train_y))
    return train_x.unsqueeze(-1), train_y, data_length  # 對train_data增加了一維數據

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=5,
            batch_first=True
        )
        self.out = nn.Linear(5, 3)

    def forward(self, x):
        r_out, h_state = self.lstm(x, None)
        # r_out, out_len = rnn_utils.pad_packed_sequence(r_out, batch_first=True)
        # x = self.out(r_out[:, -1, :])這里用r_out最后一個時間步的結果是不對的,因為此時又給補齊0了,直接用這個結果去訓練結果很差,因為有效的值太少了。
     x = self.out(h_n[-1, :, :]) #可以直接用h_n的最后一層的結果,與最后一個時刻的r_out結果相同,且是tensor,不用再用pad_packed轉換對齊了
return x net = LSTM() train_dataloader = Data.DataLoader(train_x, batch_size=2, collate_fn=collate_fn) loss_func = nn.CrossEntropyLoss() flag = 0 for data_x, data_y, length in train_dataloader: data = rnn_utils.pack_padded_sequence(data_x, length, batch_first=True) output= net(data.float()) if flag == 0: print(output.size()) print(output.shape) print(output) flag = 1

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM