Python numpy 浮點數精度問題
在復現FP(fictitious play, Iterative solution of games by fictitious play-page393)算法的時候,迭代到中間發現沒法復現paper里的結果,發現是numpy矩陣運算浮點數精度的問題。
- 具體問題
矩陣和向量相乘
\[\begin{pmatrix} 3 & 1 & 1 & 1 \end{pmatrix} \times \begin{pmatrix} 3 & 1.1 & 1.2 \\ 1.3 & 2 & 0 \\ 0 & 1 & 3.1 \\ 2 & 1.5 & 1.1 \end{pmatrix} = \begin{pmatrix} 12.3 & 7.8 & 7.8 \end{pmatrix} \]
然后取argmin
想得到第一個7.8的index,也就是1。但由於精度的問題,導致兩個7.8實際不一樣大,取到了第二個7.8的index。
具體問題代碼為
import numpy as np
x = np.matrix([3,1,1,1])*np.matrix([[3,1.1,1.2],[1.3,2,0],[0,1,3.1],[2,1.5,1.1]])
print('matrix: ',x)
print('value: ',x[0,0],x[0,1],x[0,2])
print('index: ',np.argmin(x))
得到
matrix: [[12.3 7.8 7.8]]
value: 12.3 7.800000000000001 7.799999999999999
index: 2
可以發現明明相同的兩個7.8由於精度變成了兩個大小不同的數,所以argmin
得到了2。
- 解決辦法
二進制固有的問題,只能自己手動近似,用保留小數點位數消除誤差。
如這里保留5位小數:
import numpy as np
x = np.round(np.matrix([3,1,1,1])*np.matrix([[3,1.1,1.2],[1.3,2,0],[0,1,3.1],[2,1.5,1.1]]),5)
print('matrix: ',x)
print('value: ',x[0,0],x[0,1],x[0,2])
print('index: ',np.argmin(x))
得到
matrix: [[12.3 7.8 7.8]]
value: 12.3 7.8 7.8
index: 1
- 注意事項
這個辦法不能解決所有問題,畢竟每個問題精度要求不一樣。但由於計算機二進制的原因,沒法從根本上解決,只能通過近似的方式,具體問題具體解決。