numpy sum axis詳解


axis

先看懂numpy.argmax的含義.那么numpy.sum就非常好理解.
看一維的例子.

import numpy as np
a = np.array([1, 5, 5, 2])
print(np.sum(a, axis=0))

上面代碼就是把各個值加相加.默認axis為0.axis在二維以上數組中才能體現出來作用.

import numpy as np
a = np.array([[1, 5, 5, 2],
              [9, 6, 2, 8],
              [3, 7, 9, 1]])
print(np.sum(a, axis=0))

為了描述方便,a就表示這個二維數組,np.sum(a, axis=0)的含義是a[0][j],a[1][j],a[2]j對應項相加的結果.即[1,5,5,2]+[9,6,2,8]+[3,7,9,1]=[13,18,16,11].接着看axis=1的情況.

import numpy as np
a = np.array([[1, 5, 5, 2],
              [9, 6, 2, 8],
              [3, 7, 9, 1]])
print(np.sum(a, axis=1))

np.sum(a, axis=1)的含義是a[i][0],a[i][1],a[i][2],a[i]3對應項相加的結果.即[1,9,3]+[5,6,7]+[5,2,9]+[2,8,1]=[13,25,20].
三維情況是類似的.

import numpy as np
a = np.array([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],

              [
                  [-1, 5, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ]
            ])
print(np.sum(a, axis=0))

np.sum(a, axis=0)的含義是a[0][j][k],a[1][j][k] (j=0,1,2,k=0,1,2,3)中對應項相加的結果.[[1, 5, 5, 2],[9, -6, 2, 8],[-3, 7, -9, 1]]+[[-1, 5, -5, 2],[9, 6, 2, 8],[3, 7, 9, 1]]=[[0,10,0,4],[18,0,4,16],[0,14,0,2]]. axis=1,axis=2的道理是類似的.

keepdims

keepdims的含義是是否保持維數,默認是false.通過上面的例子可以發現sum之后3維變成2維.2維變成1維.keepdims=True,最直觀的理解就是把sum結果又加一個[],以保持它的維度不變.這在某些場景有非常有用.

import numpy as np
a = np.array([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],

              [
                  [-1, 5, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ]
            ])
print(np.sum(a, axis=0, keepdims=True))

可以和上面的例子對比下結果.

參考資料

numpy官方文檔


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM