pytorch中y.data.norm()的含义


import torch
x = torch.randn(3, requires_grad=True)
y = x*2
print(y.data.norm())
print(torch.sqrt(torch.sum(torch.pow(y,2))))  #其实就是对y张量L2范数,先对y中每一项取平方,之后累加,最后取根号
i=0
while y.data.norm()<1000:
  y = y*2
  i+=1
print(y)
print(i)

结果:

tensor(3.7025)
tensor(3.7025, grad_fn=<SqrtBackward>)
tensor([ 1066.4563, -1511.3652,  -414.6933], grad_fn=<MulBackward0>)
9

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM