PyTorch: adding L1 regularization on the weights to the loss function


L1 regularization drives the weights toward sparsity. A typical application: pruning the vocabulary of a one-hot bag-of-words model by filtering words on their learned weights; the sparser the weights, the cleaner the cut.

L1_Weight is a hyperparameter; a value such as 1e-4 is a reasonable starting point.
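In isolation, the penalty is just L1_Weight times the sum of the absolute values of the layer's parameters. A minimal sketch (using a toy nn.Linear layer, not the article's model) showing that the term is an ordinary differentiable scalar that can be added to any loss:

```python
import torch
import torch.nn as nn

L1_Weight = 1e-4  # the hyperparameter value suggested above

fc = nn.Linear(4, 2)  # toy layer standing in for model.fc
# Sum of |w| over the layer's weight matrix and bias vector
l1_penalty = L1_Weight * sum(p.abs().sum() for p in fc.parameters())

# The penalty is a scalar tensor that participates in autograd,
# so (loss + l1_penalty).backward() pushes weights toward zero.
```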

import logging

import numpy as np
import torch
from tqdm import tqdm

import utils  # project-specific module providing binary_acc

L1_Weight = 1e-4  # strength of the L1 penalty


def train(model, iterator, optimizer, criteon):
    avg_acc, avg_loss = [], []
    model.train()

    for batch in tqdm(iterator):
        text, label = batch[0].cuda(), batch[1].cuda()

        pred = model(text)
        # L1 penalty over the final fully connected layer only
        l1_penalty = L1_Weight * sum(p.abs().sum() for p in model.fc.parameters())
        loss = criteon(pred, label.long())
        loss_with_penalty = loss + l1_penalty

        acc = utils.binary_acc(torch.argmax(pred.cpu(), dim=1), label.cpu().long())
        avg_acc.append(acc)
        avg_loss.append(loss.item())  # log the task loss without the penalty

        optimizer.zero_grad()
        loss_with_penalty.backward()  # gradients include the L1 term
        optimizer.step()

    avg_acc = np.array(avg_acc).mean()
    avg_loss = np.array(avg_loss).mean()
    train_metrics = {'train_acc': avg_acc,
                     'train_loss': avg_loss}
    logging.info(train_metrics)
    return avg_acc, avg_loss
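To see the vocabulary-pruning effect end to end, here is a self-contained toy sketch (hypothetical setup, not the article's model): a linear classifier over a 10-"word" vocabulary in which only word 0 predicts the label. A deliberately large L1 weight is used so sparsity shows up within a few hundred steps; after training, words are kept or dropped by the magnitude of their weight column.

```python
import torch
import torch.nn as nn

L1_WEIGHT = 0.05  # much larger than 1e-4, to make sparsity visible quickly

torch.manual_seed(0)
fc = nn.Linear(10, 2)  # classifier over a 10-word one-hot vocabulary
optimizer = torch.optim.SGD(fc.parameters(), lr=0.1)
criteon = nn.CrossEntropyLoss()

x = torch.randn(256, 10)
y = (x[:, 0] > 0).long()  # the label depends only on word 0

for _ in range(500):
    pred = fc(x)
    l1_penalty = L1_WEIGHT * sum(p.abs().sum() for p in fc.parameters())
    loss = criteon(pred, y) + l1_penalty
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Prune the vocabulary: keep words whose weight column is non-negligible
col_norm = fc.weight.abs().sum(dim=0)
kept = (col_norm > 0.05).nonzero().flatten().tolist()
print(kept)  # word 0 survives; irrelevant columns shrink toward zero
```

The same thresholding step is what you would apply to model.fc.weight after training with the loop above to decide which vocabulary entries to drop.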

 

