pytorch 踩坑筆記之w.grad.data.zero_()


  在使用pytorch實現多項線性回歸中,在grad更新時,每一次運算后都需要將上一次的梯度記錄清空,運用如下方法:

     w.grad.data.zero_()
     b.grad.data.zero_() 

   但是,運行程序就會報如下錯誤:

  報錯,grad沒有data這個屬性,

  原因是,在系統將w的grad值初始化為none,第一次求梯度計算是在none值上進行報錯,自然會沒有data屬性

  修改方法:添加一個判斷語句,從第二次循環開始執行求導運算

for i in range(100):
    y_pred = multi_linear(x_train)
    loss = getloss(y_pred,y_train)
    if i != 0: w.grad.data.zero_() b.grad.data.zero_()
    loss.backward()
    w.data = w.data - 0.001 * w.grad.data
    b.data = b.data - 0.001 * b.grad.data

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM