pytorch的nn.MSELoss損失函數


MSE是mean squared error的縮寫,即平均平方誤差,簡稱均方誤差。

MSE是逐元素計算的,計算公式為:

舊版的nn.MSELoss()函數有reduce、size_average兩個參數,新版的只有一個reduction參數了,功能是一樣的。reduction的意思是維度要不要縮減,以及怎么縮減,有三個選項:

  • 'none': no reduction will be applied.
  • 'mean': the sum of the output will be divided by the number of elements in the output.
  • 'sum': the output will be summed.

如果不設置reduction參數,默認是'mean'。

程序示例: 

import torch
import torch.nn as nn

a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
b = torch.tensor([[3, 5], [8, 6]], dtype=torch.float)

loss_fn1 = torch.nn.MSELoss(reduction='none')
loss1 = loss_fn1(a.float(), b.float())
print(loss1)   # 輸出結果:tensor([[ 4.,  9.],
               #                 [25.,  4.]])

loss_fn2 = torch.nn.MSELoss(reduction='sum')
loss2 = loss_fn2(a.float(), b.float())
print(loss2)   # 輸出結果:tensor(42.)


loss_fn3 = torch.nn.MSELoss(reduction='mean')
loss3 = loss_fn3(a.float(), b.float())
print(loss3)   # 輸出結果:tensor(10.5000)

 

對於三維的輸入也是一樣的:

a = torch.randint(0, 9, (2, 2, 3)).float()
b = torch.randint(0, 9, (2, 2, 3)).float()
print('a:\n', a)
print('b:\n', b)

loss_fn1 = torch.nn.MSELoss(reduction='none')
loss1 = loss_fn1(a.float(), b.float())
print('loss_none:\n', loss1)

loss_fn2 = torch.nn.MSELoss(reduction='sum')
loss2 = loss_fn2(a.float(), b.float())
print('loss_sum:\n', loss2)


loss_fn3 = torch.nn.MSELoss(reduction='mean')
loss3 = loss_fn3(a.float(), b.float())
print('loss_mean:\n', loss3)

運行結果:

 

 參考資料:

pytorch的nn.MSELoss損失函數

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM