關於Pytorch報警告:Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead


在使用Pytorch的時候,遇到警告的日志打印:

[W IndexingUtils.h:20] Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (function expandTensors)
[W ..\aten\src\ATen\native\cuda\LegacyDefinitions.cpp:55] Warning: masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (func
tion masked_scatter__cuda)

 

 

雖然不影響使用,但是由於滿屏的警告,使得覆蓋了重要的輸出信息。

解決辦法:

在StackOverflow上,看到有回答:

 

 在自己的代碼中,找到 mask 將其 .byte( ) 改為 .bool( ) 即可。

mask = autograd.Variable(torch.zeros((a, b))).byte()  

改為:

mask = autograd.Variable(torch.zeros((a, b))).bool()  

一般問題都是出現在對 loss的自動求導上。

修改完后,發現訓練速度都快很多:

Before:

 

 After:

 

 

可以看到,因為不需要每一步都打印警告日志,訓練速度比之前快很多。




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM