均方誤差(Mean Squared Error)是度量模型性能的一種方法。
假設m是樣本集的總個數
是第i個樣本的預測值,
是第i個樣本的真實值。
pytorch中的均方誤差函數
torch.nn.functional.mse_loss(
input, # 預測
target, # 目標
)
代碼實現:
>>>import torch >>>import torch.nn.functional as F >>> x = torch.randn(5, 10) >>> w = torch.randn(10, 10) >>> logits = x @ w.t() >>> pred = torch.sigmoid(logits) # 預測值
>>> loss = F.mse_loss(pred, target) # 計算mse_loss