pytorch強行讓batchsize變大(多次forward 1次更新)


1. 正常情況下是1次forward 1次更新,代碼為:

optimizer.zerograd

y = model(x)

loss_mse = torch.MSE(x, y)

loss_mse.backward()

optimizer.step()

 

其實只需要加3行代碼

2. 當想要讓batchsize強行變大時(多次forward 1次更新),代碼為:

if i_count == 1:

  optimizer.zerograd

y = model(x)                        

loss_mse = torch.MSE(x, y)        %這幾行不變,相當於只是加了下面的3行

loss_mse.backward

if i_ite % batchsize_loss == 0:

  optimizer.step()

  optimizer.zerograd

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM