1. 設置打印精
Pytorch中tensor打印的數據長度需要使用torch.set_printoptions(precision=xx)
進行設置,否則打印的長度會很短,給人一種精度不夠的錯覺:
>>> import torch >>> a=torch.tensor([1/3]) >>> a tensor([0.3333]) >>> # 修改打印精度為20位小數 >>> torch.set_printoptions(precision=20) >>> a tensor([0.33333334326744079590])
2. 類型轉換對精度的影響
這里考慮使用類型轉換將單精度浮點轉換為雙精度浮點:
>>> # 將單精度浮點轉換為雙精度浮點 >>> c=a.double() >>> c tensor([0.33333334326744079590], dtype=torch.float64)
可以看到,使用類型轉換並不會提升數據精度
3. 重新定義高精度數據類型
那么,重新定義一個雙精度的浮點數會怎么樣呢?
>>> # 使用雙精度浮點類型重新生成 >>> b=torch.tensor([1/3],dtype=torch.double) >>> b tensor([0.33333333333333331483], dtype=torch.float64)
4. 數據整體精度是否變化
此時,將數據加上100,可以看到小數后的精度變低了,但是數據整體精度保持不變:
1 >>> # 測試精度位數變化情況 2 >>> d=100+b 3 >>> d 4 tensor([100.33333333333332859638], dtype=torch.float64)
5. 建議
使用torch.set_default_dtype(torch.double)設置默認的數據類型為雙精度浮點,使用torch.set_default_tensor_type(torch.DoubleTensor)在設置默認數據類型的同時會設置torch.tensor接口的默認類型。[3][4]對單精度浮點,為穩妥起見,根據輸入內容設置打印精度為6位有效數字,同樣對雙精度浮點,根據輸入內容設置設置打印精度為16位有效數字。這樣,打印出來的值就是較為精確的值了。
參考: