1. 讓張量使用Variable類型,如下所示
1 from torch.autograd import Variable 2 3 inp = torch.zeros(2, 3) 4 inp = Variable(inp).type(torch.LongTensor) 5 print(inp)
Variable類型包裝了Tensor類型,並提供了backward()接口
使用Variable類型的好處是,可以按照論文公式來直接使用,並在做張量運算之后,使用繼承的backward()直接進行反向傳播
2. 自定義類繼承nn.Module
1 class CustomMSELoss(nn.Module): 2 def __init__(self): 3 super().__init__() 4 5 def forward(self, x, y): 6 return torch.mean(torch.pow((x - y), 2))
這種方法結構化程度高,在開發給用戶使用時,由於不知道用戶的Tensor是否是Variable類型,采用該方法可以減少問題。