一个多分支输出网络(一个Encoder,多个Decoder)
我们期望每个分支的损失L_i分别对各自的参数进行优化,而共享参数部分采用各分支损失之和Sum(L_i)进行优化。
在pytorch中是默认支持这种操作的,也就是我们可以分别计算出各分支的loss,然后直接把他们相加即可。(参考上面pytorch论坛的资料)
但是需要注意的是:但是如果您想要为您的损失按顺序向后调用,您可能想要在.backward()调用中指定retain_graph=True,否则中间变量将被清除。
如果分别对每个损失调用backward(),不加retain_grath=True肯定会遇到这个错误(pytorch会自动报错)。更简单的方法是将损失相加。这两种方法是等价的,因为多个向后调用的梯度将被累加。
方式一:
optimizer = optim.SGD(params=my_params_list, lr=....) loss_func = nn.CrossEntropyLoss() loss1 = loss_func(y1, target1) loss2 = loss_func(y2. target2) loss = loss1 + loss2 optimizer.zero_grad() loss.backward() optimizer.step()
方式二:
optimizer = optim.SGD(params=my_params_list, lr=....) loss_func = nn.CrossEntropyLoss() loss1 = loss_func(y1, target1) loss2 = loss_func(y2. target2) optimizer.zero_grad() loss1.backward(retain_graph=True) loss2.backward() optimizer.step()
建议采用方式一进行计算,第一个是优雅和正确的。
## you can simply do: o1, o2 = mm(input) o = o1 + o2 # loss ## Or you can do l1 = loss(o1, target) l2 = loss2(o2, target2) torch.autograd.backward([l1, l2])
如果想不同的分支采用不同的优化器:
opt1 = optim.Adam(branch_1.parameters(), ...) opt2 = optim.SGD(branch_2.parameters(), ...) ... ... loss = 2*loss_1 + 3 *loss_2 loss.backward() opt1.step() opt2.step()
参考:
How to train multi-branch output network?
pytorch辅助损失函数反向传播的疑问?