一開始寫這篇隨筆的時候還沒有了解到 Dateloader有一個 collate_fn 的參數,通過定義一個collate_fn 函數,其實很多batch補齊到當前batch最長的操作可以放在collate_fn 里面去,這樣代碼在訓練和模型中就可以更加簡潔。有時間再整理一下這個吧。
_________________________________________
使用的主要部分包括:Dateset、 Dateloader、MSELoss、PackedSequence、pack_padded_sequence、pad_packed_sequence
模型包含LSTM模塊。
參考了下面兩篇博文,總結了一下。對PackedSequence相關的理解可以先看這兩篇。本文主要是把這些應用從數據准備到loss計算都串起來大致提供了一下代碼思路,權當給自己的提醒備份吧。或者看完下面兩篇,但是不知道具體怎么操作的朋友們一個參考。
http://www.cnblogs.com/lindaxin/p/8052043.html#commentform
https://blog.csdn.net/lssc4205/article/details/79474735
使用Dateset構建數據集的時候,在__getitem__函數中把所有數據先補齊到 全局最長序列的長度。
def __getitem__(self, index): ''' get original data 此處省略獲取原始數據的代碼 input_data,output_data 數據shape是 seq_length * feature_dim ''' # 當前seq_length小於所有數據中的最長數據長度,則補0到同一長度。 ori_length = input_data.shape[0] if ori_length < self.max_len: npi = np.zeros(self.input_feature_dim, dtype=np.float32) npi = np.tile(npi, (self.max_len - ori_length,1)) input_data = np.row_stack((input_data, npi)) npo = np.zeros(self.output_feature_dim, dtype=np.float32) npo = np.tile(npo, (self.max_len - ori_length,1)) output_data = np.row_stack((output_data, npo)) return input_data, output_data, ori_length, input_data_path
在模型中,forward的實現中,需要在LSTM之前使用pack_padded_sequence、在LSTM之后使用pad_packed_sequence,中間還涉及到順序的還原之類的操作。
def forward(self, input_x, length_list, hidden=None): if hidden is None: # 這里沒用 配置中的batch_size,而是直接在input_x中取batch_size是為了防止last_batch的batch_size不是配置中的那個,引發bug h_0 = input_x.data.new(self.directional*self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float() c_0 = input_x.data.new(self.directional*self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float() else: h_0, c_0 = hidden ''' 省略模型其他部分,直接進去LSTM前后的操作 ''' _, idx_sort = torch.sort(length_list, dim=0, descending=True) _, idx_unsort = otrch.sort(idx_sort, dim=0) input_x = input_x.index_select(0, Variable(idx_sort)) length_list = list(length_list[idx_sort]) pack = nn_utils.rnn.pack_padded_sequence(input_x, length_list, batch_first=self.batch_first) output, hidden = self.BiLSTM(pack, (h0, c0)) un_padded = nn_utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) un_padded = un_padded[0].index_select(0, Variable(idx_unsort)) # 此時的un_padded已經完成了還原,並且補0完成,而且這時的補0到的序列長度是當前batch的最長長度,而不是Dateset中的全局最長長度!
# 所以在main train函數中也要對label的seq做處理 return un_padded
main train中,要對label做相應的截斷處理,因為模型返回的長度已經是補齊到當前batch的最長序列長度了,而dateset返回的label是補齊到全局最長序列長度。算loss的時候,MSELoss的reduce參數要設置成false,讓loss函數返回一個loss矩陣,再構造一個01掩膜矩陣mask,矩陣相乘求和得到真的loss(達到填充0的位置不參與loss的目的)
def train(**kwargs):
train_data = my_dataset()
train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
model = getattr(models, opt.model)(batchsize=opt.batch_size)
criterion = torch.nn.MSELoss(reduce=False)
lr = opt.lf
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
for epoch in range(opt.start_epoch, opt.max_epoch):
for ii, (data, label, length_list,_) in tqdm(enumerate(train_dataloader)):
cur_batch_max_len = length_list.max()
data = Variable(data)
target = Variable(label)
optimizer.zero_grad()
score = model(data, length_list)
loss_mat = criterion(score, target)
list_int = list(length_list)
mask_mat = Variable(t.ones(len(list_int),cur_batch_max_len,opt.output_feature_dim))
num_element = 0
for idx_sample in range(len(list_int)):
num_element += list_int[idx_sample] * opt.output_feature_dim
if list_int[idx_sample] != cur_batch_max_len:
mask_mat[idx_sample, list[idx_sample]:] = 0.0
loss = (loss_mat * mask_mat).sum() / num_element
loss.backward()
optimizer.step()