Pytorch中tensor的打印精度


1. 設置打印精

Pytorch中tensor打印的數據長度需要使用torch.set_printoptions(precision=xx)進行設置,否則打印的長度會很短,給人一種精度不夠的錯覺:

>>> import torch
>>> a=torch.tensor([1/3])
>>> a
tensor([0.3333])
>>> # 修改打印精度為20位小數
>>> torch.set_printoptions(precision=20)
>>> a
tensor([0.33333334326744079590])

2. 類型轉換對精度的影響

這里考慮使用類型轉換將單精度浮點轉換為雙精度浮點:

>>> # 將單精度浮點轉換為雙精度浮點
>>> c=a.double()
>>> c
tensor([0.33333334326744079590], dtype=torch.float64)

可以看到,使用類型轉換並不會提升數據精度

3. 重新定義高精度數據類型

那么,重新定義一個雙精度的浮點數會怎么樣呢?

>>> # 使用雙精度浮點類型重新生成
>>> b=torch.tensor([1/3],dtype=torch.double)
>>> b
tensor([0.33333333333333331483], dtype=torch.float64)

4. 數據整體精度是否變化

此時,將數據加上100,可以看到小數后的精度變低了,但是數據整體精度保持不變:

 

1 >>> # 測試精度位數變化情況
2 >>> d=100+b
3 >>> d
4 tensor([100.33333333333332859638], dtype=torch.float64)

 

5. 建議

使用torch.set_default_dtype(torch.double)設置默認的數據類型為雙精度浮點,使用torch.set_default_tensor_type(torch.DoubleTensor)在設置默認數據類型的同時會設置torch.tensor接口的默認類型。[3][4]對單精度浮點,為穩妥起見,根據輸入內容設置打印精度為6位有效數字,同樣對雙精度浮點,根據輸入內容設置設置打印精度為16位有效數字。這樣,打印出來的值就是較為精確的值了。

 

參考:

Pytorch中tensor的打印精度_步子大了吧的博客-CSDN博客_pytorch 設置精度

python與pytorch的數據類型、數據精度與轉換_hangyangSJTU的博客-CSDN博客_torch 精度


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM