一個多分支輸出網絡(一個Encoder,多個Decoder)
我們期望每個分支的損失L_i分別對各自的參數進行優化,而共享參數部分采用各分支損失之和Sum(L_i)進行優化。
在pytorch中是默認支持這種操作的,也就是我們可以分別計算出各分支的loss,然后直接把他們相加即可。(參考上面pytorch論壇的資料)
但是需要注意的是:但是如果您想要為您的損失按順序向后調用,您可能想要在.backward()調用中指定retain_graph=True,否則中間變量將被清除。
如果分別對每個損失調用backward(),不加retain_grath=True肯定會遇到這個錯誤(pytorch會自動報錯)。更簡單的方法是將損失相加。這兩種方法是等價的,因為多個向后調用的梯度將被累加。
方式一:
optimizer = optim.SGD(params=my_params_list, lr=....) loss_func = nn.CrossEntropyLoss() loss1 = loss_func(y1, target1) loss2 = loss_func(y2. target2) loss = loss1 + loss2 optimizer.zero_grad() loss.backward() optimizer.step()
方式二:
optimizer = optim.SGD(params=my_params_list, lr=....) loss_func = nn.CrossEntropyLoss() loss1 = loss_func(y1, target1) loss2 = loss_func(y2. target2) optimizer.zero_grad() loss1.backward(retain_graph=True) loss2.backward() optimizer.step()
建議采用方式一進行計算,第一個是優雅和正確的。
## you can simply do: o1, o2 = mm(input) o = o1 + o2 # loss ## Or you can do l1 = loss(o1, target) l2 = loss2(o2, target2) torch.autograd.backward([l1, l2])
如果想不同的分支采用不同的優化器:
opt1 = optim.Adam(branch_1.parameters(), ...) opt2 = optim.SGD(branch_2.parameters(), ...) ... ... loss = 2*loss_1 + 3 *loss_2 loss.backward() opt1.step() opt2.step()
參考:
How to train multi-branch output network?
pytorch輔助損失函數反向傳播的疑問?