axis
先看懂numpy.argmax的含義.那么numpy.sum就非常好理解.
看一維的例子.
import numpy as np
a = np.array([1, 5, 5, 2])
print(np.sum(a, axis=0))
上面代碼就是把各個值加相加.默認axis為0.axis在二維以上數組中才能體現出來作用.
import numpy as np
a = np.array([[1, 5, 5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]])
print(np.sum(a, axis=0))
為了描述方便,a就表示這個二維數組,np.sum(a, axis=0)的含義是a[0][j],a[1][j],a[2]j對應項相加的結果.即[1,5,5,2]+[9,6,2,8]+[3,7,9,1]=[13,18,16,11].接着看axis=1的情況.
import numpy as np
a = np.array([[1, 5, 5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]])
print(np.sum(a, axis=1))
np.sum(a, axis=1)的含義是a[i][0],a[i][1],a[i][2],a[i]3對應項相加的結果.即[1,9,3]+[5,6,7]+[5,2,9]+[2,8,1]=[13,25,20].
三維情況是類似的.
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]
])
print(np.sum(a, axis=0))
np.sum(a, axis=0)的含義是a[0][j][k],a[1][j][k] (j=0,1,2,k=0,1,2,3)中對應項相加的結果.[[1, 5, 5, 2],[9, -6, 2, 8],[-3, 7, -9, 1]]+[[-1, 5, -5, 2],[9, 6, 2, 8],[3, 7, 9, 1]]=[[0,10,0,4],[18,0,4,16],[0,14,0,2]]. axis=1,axis=2的道理是類似的.
keepdims
keepdims的含義是是否保持維數,默認是false.通過上面的例子可以發現sum之后3維變成2維.2維變成1維.keepdims=True,最直觀的理解就是把sum結果又加一個[],以保持它的維度不變.這在某些場景有非常有用.
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]
])
print(np.sum(a, axis=0, keepdims=True))
可以和上面的例子對比下結果.