pytorch自定義loss函數的幾種方法


1. 讓張量使用Variable類型,如下所示

1 from torch.autograd import Variable
2 
3 inp = torch.zeros(2, 3)
4 inp = Variable(inp).type(torch.LongTensor)
5 print(inp)

Variable類型包裝了Tensor類型,並提供了backward()接口

使用Variable類型的好處是,可以按照論文公式來直接使用,並在做張量運算之后,使用繼承的backward()直接進行反向傳播

2. 自定義類繼承nn.Module

1 class CustomMSELoss(nn.Module):
2     def __init__(self):
3         super().__init__()
4         
5     def forward(self, x, y):
6         return torch.mean(torch.pow((x - y), 2))

這種方法結構化程度高,在開發給用戶使用時,由於不知道用戶的Tensor是否是Variable類型,采用該方法可以減少問題。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM