1. 设置打印精
Pytorch中tensor打印的数据长度需要使用torch.set_printoptions(precision=xx)
进行设置,否则打印的长度会很短,给人一种精度不够的错觉:
>>> import torch >>> a=torch.tensor([1/3]) >>> a tensor([0.3333]) >>> # 修改打印精度为20位小数 >>> torch.set_printoptions(precision=20) >>> a tensor([0.33333334326744079590])
2. 类型转换对精度的影响
这里考虑使用类型转换将单精度浮点转换为双精度浮点:
>>> # 将单精度浮点转换为双精度浮点 >>> c=a.double() >>> c tensor([0.33333334326744079590], dtype=torch.float64)
可以看到,使用类型转换并不会提升数据精度
3. 重新定义高精度数据类型
那么,重新定义一个双精度的浮点数会怎么样呢?
>>> # 使用双精度浮点类型重新生成 >>> b=torch.tensor([1/3],dtype=torch.double) >>> b tensor([0.33333333333333331483], dtype=torch.float64)
4. 数据整体精度是否变化
此时,将数据加上100,可以看到小数后的精度变低了,但是数据整体精度保持不变:
1 >>> # 测试精度位数变化情况 2 >>> d=100+b 3 >>> d 4 tensor([100.33333333333332859638], dtype=torch.float64)
5. 建议
使用torch.set_default_dtype(torch.double)设置默认的数据类型为双精度浮点,使用torch.set_default_tensor_type(torch.DoubleTensor)在设置默认数据类型的同时会设置torch.tensor接口的默认类型。[3][4]对单精度浮点,为稳妥起见,根据输入内容设置打印精度为6位有效数字,同样对双精度浮点,根据输入内容设置设置打印精度为16位有效数字。这样,打印出来的值就是较为精确的值了。
参考: