关于Pytorch报警告:Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead


在使用Pytorch的时候,遇到警告的日志打印:

[W IndexingUtils.h:20] Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (function expandTensors)
[W ..\aten\src\ATen\native\cuda\LegacyDefinitions.cpp:55] Warning: masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (func
tion masked_scatter__cuda)

 

 

虽然不影响使用,但是由于满屏的警告,使得覆盖了重要的输出信息。

解决办法:

在StackOverflow上,看到有回答:

 

 在自己的代码中,找到 mask 将其 .byte( ) 改为 .bool( ) 即可。

mask = autograd.Variable(torch.zeros((a, b))).byte()  

改为:

mask = autograd.Variable(torch.zeros((a, b))).bool()  

一般问题都是出现在对 loss的自动求导上。

修改完后,发现训练速度都快很多:

Before:

 

 After:

 

 

可以看到,因为不需要每一步都打印警告日志,训练速度比之前快很多。




免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM