pytorch中y.data.norm()的含義


import torch
x = torch.randn(3, requires_grad=True)
y = x*2
print(y.data.norm())
print(torch.sqrt(torch.sum(torch.pow(y,2))))  #其實就是對y張量L2范數,先對y中每一項取平方,之后累加,最后取根號
i=0
while y.data.norm()<1000:
  y = y*2
  i+=1
print(y)
print(i)

結果:

tensor(3.7025)
tensor(3.7025, grad_fn=<SqrtBackward>)
tensor([ 1066.4563, -1511.3652,  -414.6933], grad_fn=<MulBackward0>)
9

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM