Pytorch 解决lstm输入可变长序列问题


我是做图像到语音音素识别相关的科研工作的,需要用到lstm识别一个序列帧所对应的音素,但是不同音素有不同长度的帧,所以需要解决变长序列的问题。

需要解决这个问题的原因是:不等长的序列无法进行batch

我主要参考的文章:https://blog.csdn.net/kejizuiqianfang/article/details/100835528

但我又不能完全套用这篇文章的方法。这篇文章举例用的是无标签的数据集,但我们很可能处理有便签的数据集,所以这篇文章主要说明如何处理有标签的数据集

 

主要使用的也是这三个方法:

torch.nn.utils.rnn.pad_sequence()  把一个batch列表(注意必须是列表,不能是元组)中不等长的tensor补充成等长的tensor后返回batch的tensor

torch.nn.utils.rnn.pack_padded_sequence()  把pad后的tensor压缩成各个序列的实际长度,同时数据变成PackedSequence类型

torch.nn.utils.rnn.pad_packed_sequence()  把上面所压缩成PackedSequence的数据还原成tensor类型, 并补成等长的数据

注意不要忘记 import torch.nn.utils.rnn as rnn_utils

 

这次我举一个有标签数据的例子

train_x = [(torch.tensor([1, 2, 3, 4, 5, 6, 7]), torch.tensor(1)),
           (torch.tensor([2, 3, 4, 5, 6, 7]), torch.tensor(1)),
           (torch.tensor([3, 4, 5, 6, 7]), torch.tensor(1)),
            (torch.tensor([4, 5, 6, 7]), torch.tensor(1)),
             (torch.tensor([5, 6, 7]), torch.tensor(1)),
              (torch.tensor([6, 7]), torch.tensor(1)),
               (torch.tensor([7]), torch.tensor(1))]

这里的train_x是已经经过class myDataset(Data.Dataset)返回的数据格式。我们需要注意的是train_x这个列表中是一个一个元组,元组的第一个元素是序列tensor,第二个元素是标签tensor。
物品参考的文章中与这不同的是列表里存的数一个一个序列tensor。
如果按照上文那样pad会报一个错误:

 

 所以,在处理collate_fn函数时需要注意把元组中的序列和标签分别取出来组成list。注意在排序的时候要对train_data排序,主要要和标签一起排

def collate_fn(train_data):
    # print("排序前:", train_data)
    train_data.sort(key=lambda data: len(data[0]), reverse=True)
    # print("排序后:", train_data)
    train_x = []
    train_y = []
    for data in train_data:
        train_x.append(data[0])
        train_y.append(data[1])

    data_length = [len(data) for data in train_x]
    print(data_length)
    train_x = rnn_utils.pad_sequence(train_x, batch_first=True, padding_value=0)
    train_y = torch.from_numpy(np.asarray(train_y))
    return train_x.unsqueeze(-1), train_y, data_length 

 

在输入到lstm之前,需要进行压缩,提高效率

data = rnn_utils.pack_padded_sequence(data_x, length, batch_first=True)
    output= net(data.float())

 

如果有全连接层,需要在进入全连接层之前,进行pad_packed_sequence

下面是整体的代码:

import torch.nn.utils.rnn as rnn_utils
import torch
import torch.utils.data as Data
import torch.nn as nn
import numpy as np

train_x = [(torch.tensor([1, 2, 3, 4, 5, 6, 7]), torch.tensor(1)),
           (torch.tensor([2, 3, 4, 5, 6, 7]), torch.tensor(1)),
           (torch.tensor([3, 4, 5, 6, 7]), torch.tensor(1)),
            (torch.tensor([4, 5, 6, 7]), torch.tensor(1)),
             (torch.tensor([5, 6, 7]), torch.tensor(1)),
              (torch.tensor([6, 7]), torch.tensor(1)),
               (torch.tensor([7]), torch.tensor(1))]

def collate_fn(train_data):
    train_x.sort(key=lambda data: len(data[0]), reverse=True)
   train_x = []
    train_y = []
    for data in train_data:
        train_x.append(data[0])
        train_y.append(data[1])

    data_length = [len(data) for data in train_x]
    # print(data_length)
    train_x = rnn_utils.pad_sequence(train_x, batch_first=True, padding_value=0)
    train_y = torch.from_numpy(np.asarray(train_y))
    return train_x.unsqueeze(-1), train_y, data_length  # 对train_data增加了一维数据

class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=5,
            batch_first=True
        )
        self.out = nn.Linear(5, 3)

    def forward(self, x):
        r_out, h_state = self.lstm(x, None)
        # r_out, out_len = rnn_utils.pad_packed_sequence(r_out, batch_first=True)
        # x = self.out(r_out[:, -1, :])这里用r_out最后一个时间步的结果是不对的,因为此时又给补齐0了,直接用这个结果去训练结果很差,因为有效的值太少了。
     x = self.out(h_n[-1, :, :]) #可以直接用h_n的最后一层的结果,与最后一个时刻的r_out结果相同,且是tensor,不用再用pad_packed转换对齐了
return x net = LSTM() train_dataloader = Data.DataLoader(train_x, batch_size=2, collate_fn=collate_fn) loss_func = nn.CrossEntropyLoss() flag = 0 for data_x, data_y, length in train_dataloader: data = rnn_utils.pack_padded_sequence(data_x, length, batch_first=True) output= net(data.float()) if flag == 0: print(output.size()) print(output.shape) print(output) flag = 1

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM