pytorch中多个loss回传的参数影响示例


写了一段代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.fc1 = nn.Linear(5, 4)
        self.fc2 = nn.Linear(4, 3)
        self.fc3 = nn.Linear(4, 3)

    def forward(self, x):
        mid = self.fc1(x)
        out1 = self.fc2(mid)
        out2 = self.fc3(mid)
        return out1, out2


x = torch.randn((3, 5))
y = torch.torch.randint(3, (3,), dtype=torch.int64)
model = Test()
model.train()
optim = torch.optim.RMSprop(model.parameters(), lr=0.001)

print(model.fc2.weight)
print(model.fc3.weight)
for i in range(5):
    out1, out2 = model(x)
    loss1 = F.cross_entropy(out1, y)
    loss2 = F.cross_entropy(out2, y)
    loss = loss1 + loss2
    optim.zero_grad()
    loss.backward()
    optim.step()
print("-------------after-----------")
print(model.fc2.weight)
print(model.fc3.weight)

在loss.backward()处分别更换为loss1.backward()和loss2.backward(),观察fc2和fc3层的参数变化。

得出的结论为:loss2只影响fc3的参数,loss1只影响fc2的参数。

(粗略分析,抛砖引玉)


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM