1. Normally there is one forward pass per update; the code is:
optimizer.zero_grad()
y = model(x)
loss_mse = F.mse_loss(y, x)   # reconstruction loss; F is torch.nn.functional
loss_mse.backward()
optimizer.step()
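For context, a minimal self-contained version of this standard loop might look like the following sketch. The toy autoencoder, data, and hyperparameters are illustrative assumptions of mine, not from the original:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy setup (all names and sizes here are assumptions for illustration)
model = nn.Sequential(nn.Linear(32, 8), nn.ReLU(), nn.Linear(8, 32))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data = torch.randn(100, 32)          # 100 fake samples of dimension 32

for x in data.split(10):             # mini-batches of 10
    optimizer.zero_grad()            # clear old gradients
    y = model(x)                     # one forward pass
    loss_mse = F.mse_loss(y, x)      # reconstruction loss
    loss_mse.backward()              # one backward pass
    optimizer.step()                 # one parameter update per batch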
In fact, only three extra lines of code are needed.
2. When you want to force a larger effective batch size (multiple forward passes per update), the code becomes:
if i_count == 1:
    optimizer.zero_grad()            # clear gradients once, before the first accumulation
y = model(x)
loss_mse = F.mse_loss(y, x)          # these lines stay the same; effectively only the three new lines are added
loss_mse.backward()                  # gradients accumulate across iterations
if i_ite % batchsize_loss == 0:      # every batchsize_loss forward passes...
    optimizer.step()                 # ...apply one update with the accumulated gradients
    optimizer.zero_grad()            # ...and reset for the next accumulation window
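As a self-contained sketch of this gradient-accumulation pattern (reusing the toy setup above; the name accum_steps is my own, not from the original), the loop could look like this. One detail the snippet above omits: dividing the loss by the number of accumulation steps keeps the accumulated gradient comparable in magnitude to a genuinely larger batch.

accum_steps = 4                      # one optimizer update every 4 forward passes

optimizer.zero_grad()                # clear gradients once before the loop
for i, x in enumerate(data.split(10), start=1):
    y = model(x)
    # Scaling by accum_steps makes the summed gradients approximate the
    # average over the larger effective batch (my addition, not in the
    # original snippet).
    loss_mse = F.mse_loss(y, x) / accum_steps
    loss_mse.backward()              # gradients add up across iterations
    if i % accum_steps == 0:
        optimizer.step()             # single update for the accumulated batch
        optimizer.zero_grad()        # reset for the next accumulation window

Here the effective batch size is 10 * 4 = 40, while peak memory stays at that of a batch of 10.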