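We compute the standard deviation of one million float values with each of the following approaches, to compare their performance:
- pure Python
- NumPy
- Cython
- C (called from Cython)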
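All four versions compute the population standard deviation (dividing by $n$, not $n-1$) with the usual two-pass definition over values $x_1,\dots,x_n$. The first pass computes the mean and the second sums the squared deviations:

$$
\mu = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(x_i-\mu\right)^2}
$$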
Pure Python version:
```python
# File: StdDev.py
import math

def pyStdDev(a):
    mean = sum(a) / len(a)
    return math.sqrt(sum((x - mean) ** 2 for x in a) / len(a))
```
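The function divides by `len(a)`, so it computes the population standard deviation. A quick sanity check, assuming a small test list, compares it with the standard library's `statistics.pstdev` (population). `statistics.stdev` would divide by n - 1 instead. `pyStdDev` is copied from StdDev.py so the snippet runs on its own:

```python
import math
import statistics

def pyStdDev(a):
    # Copy of pyStdDev from StdDev.py: two-pass population standard deviation
    mean = sum(a) / len(a)
    return math.sqrt(sum((x - mean) ** 2 for x in a) / len(a))

# Small sample so the check runs instantly
data = [float(v) for v in range(10)]
print(pyStdDev(data))             # about 2.8723
print(statistics.pstdev(data))    # about 2.8723 (population, same result)
```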
Using NumPy (in the same file):
```python
# File: StdDev.py
import numpy as np

def npStdDev(a):
    return np.std(a)
```
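`np.std` defaults to `ddof=0`, so it matches the pure-Python version (divide by n). The sketch below, using a small example array, shows that default next to the explicit formula and the `ddof=1` sample variant:

```python
import numpy as np

a = np.arange(10, dtype=np.float64)

# Default ddof=0: divide by n (population standard deviation)
pop = np.std(a)
# ddof=1: divide by n - 1 (sample standard deviation)
sample = np.std(a, ddof=1)

# The default agrees with the explicit two-pass formula
manual = np.sqrt(np.mean((a - a.mean()) ** 2))
print(pop, manual, sample)
```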
Naive Cython code:
```cython
# File: cyStdDev.pyx
import math

def cyStdDev(a):
    m = a.mean()
    w = a - m
    wSq = w**2
    return math.sqrt(wSq.mean())
```
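`cyStdDev` has no `cdef` type declarations, so its body is also valid Python. Every step in it (`a.mean()`, `a - m`, `w**2`) is a NumPy array operation, which means Cython can only remove the small interpreter overhead around those calls. The sketch below runs the same body under the plain interpreter and compares it with `np.std`. The name `cyStdDev_as_python` is hypothetical:

```python
import math
import numpy as np

def cyStdDev_as_python(a):
    # Same body as cyStdDev in cyStdDev.pyx, run as ordinary Python
    m = a.mean()      # NumPy reduction
    w = a - m         # allocates a temporary array of len(a)
    wSq = w**2        # allocates a second temporary array
    return math.sqrt(wSq.mean())

a = np.arange(1e6)
print(cyStdDev_as_python(a), np.std(a))
```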
Cython optimized with typed NumPy buffers:
```cython
# File: cyStdDev.pyx
cdef extern from "math.h":
    double sqrt(double m)

from numpy cimport ndarray
cimport numpy as np
cimport cython

@cython.boundscheck(False)
def cyOptStdDev(ndarray[np.float64_t, ndim=1] a not None):
    cdef Py_ssize_t i
    cdef Py_ssize_t n = a.shape[0]
    cdef double m = 0.0
    for i in range(n):
        m += a[i]
    m /= n
    cdef double v = 0.0
    for i in range(n):
        v += (a[i] - m)**2
    return sqrt(v / n)
```
Finally, Cython calling C code:
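```cython
# File: cyStdDev.pyx (continued; reuses the ndarray / np cimports above)
cdef extern from "std_dev.h":
    double std_dev(double *arr, size_t siz)

# mode="c" makes Cython reject non-contiguous arrays, because std_dev()
# reads a.size consecutive doubles starting at a.data
def cStdDev(ndarray[np.float64_t, ndim=1, mode="c"] a not None):
    return std_dev(<double*> a.data, a.size)
```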
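`cStdDev` hands the raw `a.data` pointer to C, and C reads `a.size` consecutive doubles from it. That is only correct for C-contiguous float64 arrays: a strided view such as `a[::2]` would be read wrongly. A minimal caller-side guard is sketched below, assuming the array may come from slicing. The helper name `as_c_double_array` is hypothetical; `np.ascontiguousarray` copies only when the layout or dtype requires it.

```python
import numpy as np

def as_c_double_array(a):
    """Hypothetical helper: return a C-contiguous float64 array that is safe
    to pass to a C routine expecting a.size consecutive doubles."""
    # No copy if `a` is already C-contiguous float64; otherwise a packed copy
    return np.ascontiguousarray(a, dtype=np.float64)

a = np.arange(1e6)
strided = a[::2]                      # view with a 16-byte stride
print(strided.flags["C_CONTIGUOUS"])  # False
safe = as_c_double_array(strided)
print(safe.flags["C_CONTIGUOUS"])     # True
print(np.shares_memory(a, as_c_double_array(a)))  # True: no copy was made
```

It would be used as `cyStdDev.cStdDev(as_c_double_array(x))`.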
The C function is declared in `std_dev.h`:
```c
/* std_dev.h */
#include <stdlib.h>  /* size_t */
double std_dev(double *arr, size_t siz);
```
and implemented in `std_dev.c`:
```c
/* std_dev.c */
#include <math.h>
#include "std_dev.h"

double std_dev(double *arr, size_t siz)
{
    double mean = 0.0;
    double sum_sq;
    double *pVal;
    double diff;

    /* First pass: mean */
    pVal = arr;
    for (size_t i = 0; i < siz; ++i, ++pVal) {
        mean += *pVal;
    }
    mean /= siz;

    /* Second pass: sum of squared deviations */
    pVal = arr;
    sum_sq = 0.0;
    for (size_t i = 0; i < siz; ++i, ++pVal) {
        diff = *pVal - mean;
        sum_sq += diff * diff;
    }
    return sqrt(sum_sq / siz);
}
```
Timing each version:
```shell
# Pure Python
python3 -m timeit -s "import StdDev; import numpy as np; a = [float(v) for v in range(1000000)]" "StdDev.pyStdDev(a)"
# NumPy
python3 -m timeit -s "import StdDev; import numpy as np; a = np.arange(1e6)" "StdDev.npStdDev(a)"
# Cython - naive
python3 -m timeit -s "import cyStdDev; import numpy as np; a = np.arange(1e6)" "cyStdDev.cyStdDev(a)"
# Optimised Cython
python3 -m timeit -s "import cyStdDev; import numpy as np; a = np.arange(1e6)" "cyStdDev.cyOptStdDev(a)"
# Cython calling C
python3 -m timeit -s "import cyStdDev; import numpy as np; a = np.arange(1e6)" "cyStdDev.cStdDev(a)"
```
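The same measurement can be scripted with the `timeit` module instead of the command line. The sketch below takes the best of several repeats, as `python3 -m timeit` does, and converts the result to milliseconds per call. `best_ms` is a hypothetical helper, and `pyStdDev` is copied from StdDev.py so the script is self-contained. Absolute numbers will vary by machine.

```python
import math
import timeit

import numpy as np

def pyStdDev(a):
    # Copy of pyStdDev from StdDev.py
    mean = sum(a) / len(a)
    return math.sqrt(sum((x - mean) ** 2 for x in a) / len(a))

def best_ms(func, arg, number=10, repeat=5):
    """Hypothetical helper: best-of-`repeat` time per call, in milliseconds."""
    times = timeit.repeat(lambda: func(arg), number=number, repeat=repeat)
    return min(times) / number * 1000.0

if __name__ == "__main__":
    lst = [float(v) for v in range(1_000_000)]
    arr = np.arange(1e6)
    print("pure Python: %.2f ms" % best_ms(pyStdDev, lst, number=2, repeat=3))
    print("NumPy:       %.2f ms" % best_ms(np.std, arr))
```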
Results:
Method | Time (ms) | Speedup vs. pure Python | Speedup vs. NumPy |
--- | --- | --- | --- |
Pure Python | 183 | 1× | 0.03× |
NumPy | 5.97 | 31× | 1× |
Cython (naive) | 7.76 | 24× | 0.8× |
Cython + NumPy types | 2.18 | 84× | 2.7× |
Cython calling C | 2.22 | 82× | 2.7× |
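Each speedup column is the baseline's time divided by the method's time (for example 183 / 5.97 ≈ 31). A small sketch that recomputes both columns from the measured times in the table:

```python
# Measured times (ms) from the results table
times_ms = {
    "Pure Python": 183,
    "NumPy": 5.97,
    "Cython (naive)": 7.76,
    "Cython + NumPy types": 2.18,
    "Cython calling C": 2.22,
}

def speedups(times, baseline):
    """Speedup of each method relative to `baseline`: t_baseline / t_method."""
    base = times[baseline]
    return {name: base / t for name, t in times.items()}

vs_python = speedups(times_ms, "Pure Python")
vs_numpy = speedups(times_ms, "NumPy")
for name in times_ms:
    print(f"{name:22s} {vs_python[name]:6.1f}x {vs_numpy[name]:5.2f}x")
```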
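Summary:
- NumPy is about 31× faster than pure Python.
- Naive Cython comes close to NumPy (about 0.8× its speed). That makes sense, since its body only calls NumPy array operations (`a.mean()`, `a - m`, `w**2`), so it is essentially NumPy with two extra temporary arrays.
- Hand-written C called from Cython is also very fast (about 82× pure Python), but still marginally behind Cython with typed NumPy buffers (2.22 ms vs. 2.18 ms). Cython + NumPy comes out on top, though a gap under 2% is small enough that it may be within run-to-run noise.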
=============================================
qsy, 23 May 2019