An example of how multiple losses affect parameter updates in PyTorch backpropagation


I wrote the following piece of test code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.fc1 = nn.Linear(5, 4)  # shared layer, feeds both heads
        self.fc2 = nn.Linear(4, 3)  # head producing out1
        self.fc3 = nn.Linear(4, 3)  # head producing out2

    def forward(self, x):
        mid = self.fc1(x)  # shared representation
        out1 = self.fc2(mid)
        out2 = self.fc3(mid)
        return out1, out2


x = torch.randn((3, 5))                        # a batch of 3 samples
y = torch.randint(3, (3,), dtype=torch.int64)  # class labels shared by both heads
model = Test()
model.train()
optim = torch.optim.RMSprop(model.parameters(), lr=0.001)

# weights before training
print(model.fc2.weight)
print(model.fc3.weight)
for i in range(5):
    out1, out2 = model(x)
    loss1 = F.cross_entropy(out1, y)
    loss2 = F.cross_entropy(out2, y)
    loss = loss1 + loss2  # joint loss over both heads
    optim.zero_grad()
    loss.backward()       # the line to swap in the experiment below
    optim.step()
print("-------------after-----------")
print(model.fc2.weight)
print(model.fc3.weight)

Now replace loss.backward() with loss1.backward(), and then with loss2.backward(), and observe how the parameters of the fc2 and fc3 layers change in each case.
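
Rather than eyeballing the printed weight matrices, the same thing can be checked by inspecting .grad directly. Below is a minimal sketch, reusing the Test class, x, and y defined above; it assumes a fresh model instance so that every .grad attribute starts out as None.

check = Test()                        # fresh model: all .grad attributes are None
o1, o2 = check(x)
F.cross_entropy(o1, y).backward()     # backpropagate through loss1 only

print(check.fc1.weight.grad is None)  # False: fc1 is upstream of out1
print(check.fc2.weight.grad is None)  # False: fc2 produced out1
print(check.fc3.weight.grad is None)  # True: fc3 is not on loss1's graph

Because fc3's gradient stays None, optim.step() leaves fc3 untouched.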

The observed result: of the two heads, loss1 only changes fc2's parameters and loss2 only changes fc3's parameters. Strictly speaking, each loss also updates the shared fc1 layer, since fc1 lies on both computation graphs; only the fc2 and fc3 weights are printed here.
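
This also explains why summing the losses before a single backward() call works: gradients accumulate in .grad, so backpropagating loss1 + loss2 is equivalent to backpropagating each loss separately. A minimal sketch of that equivalence, again reusing x and y from above (retain_graph=True keeps the shared part of the graph alive for the second call):

m1, m2 = Test(), Test()
m2.load_state_dict(m1.state_dict())  # give both copies identical weights

# Path A: one backward pass through the summed loss.
o1, o2 = m1(x)
(F.cross_entropy(o1, y) + F.cross_entropy(o2, y)).backward()

# Path B: separate backward passes; gradients accumulate in .grad.
o1, o2 = m2(x)
F.cross_entropy(o1, y).backward(retain_graph=True)
F.cross_entropy(o2, y).backward()

# The shared fc1 receives gradient from both losses either way.
print(torch.allclose(m1.fc1.weight.grad, m2.fc1.weight.grad))  # True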

(This is only a rough analysis, offered as a starting point for discussion.)

