L1 regularization drives weights toward sparsity by adding the penalty λ·Σ|w| (the sum of the absolute values of the weights, scaled by a coefficient) to the loss, which pushes many weights to exactly zero. A typical application: pruning the vocabulary of a one-hot bag-of-words model by filtering words on their learned weights, where the sparser the weights, the better.
L1_Weight is the penalty coefficient λ, a hyperparameter; 1e-4 is a reasonable starting value.
import logging

import numpy as np
import torch
from tqdm import tqdm

import utils  # project-local module providing binary_acc

L1_Weight = 1e-4  # L1 penalty coefficient (hyperparameter)

def train(model, iterator, optimizer, criterion):
    avg_acc, avg_loss = [], []
    model.train()

    for batch in tqdm(iterator):
        text, label = batch[0].cuda(), batch[1].cuda()

        pred = model(text)
        # L1 penalty: summed absolute values of the fc layer's parameters
        l1_penalty = L1_Weight * sum(p.abs().sum() for p in model.fc.parameters())
        loss = criterion(pred, label.long())
        loss_with_penalty = loss + l1_penalty

        acc = utils.binary_acc(torch.argmax(pred.cpu(), dim=1), label.cpu().long())
        avg_acc.append(acc)
        avg_loss.append(loss.item())  # log the raw loss, without the penalty term

        optimizer.zero_grad()
        loss_with_penalty.backward()  # backprop through loss + L1 penalty
        optimizer.step()

    avg_acc = np.array(avg_acc).mean()
    avg_loss = np.array(avg_loss).mean()
    train_metrics = {'train_acc': avg_acc,
                     'train_loss': avg_loss}
    logging.info(train_metrics)
    return avg_acc, avg_loss
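Once training has sparsified the fc weights, the vocabulary pruning mentioned above can be done by weight magnitude. Below is a minimal sketch of that step, assuming model.fc is a single linear layer whose input dimension equals the vocabulary size; the names prune_vocab, vocab, and keep_threshold are illustrative and not part of the original code.

# Hypothetical pruning step (prune_vocab, vocab, keep_threshold are
# illustrative names). fc.weight has shape [num_classes, vocab_size],
# so summing absolute values over dim=0 scores each word by the total
# weight mass of its column.
import torch

def prune_vocab(model, vocab, keep_threshold=1e-3):
    with torch.no_grad():
        scores = model.fc.weight.abs().sum(dim=0)  # one score per word, [vocab_size]
    keep = (scores > keep_threshold).nonzero(as_tuple=True)[0].tolist()
    return [vocab[i] for i in keep]  # pruned word list

Words whose columns carry almost no weight contribute essentially nothing to the prediction, so dropping them shrinks the one-hot input dimension at little cost in accuracy.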
