import numpy as np X = np.array([[1, 2], [4, 5], [7, 8]]) print np.mean(X, axis=0, keepdims=True) print np.mean(X, axis=1, keepdims=True)
結果是分別是
[[ 1.5] [[ 4. 5.]] [ 4.5] [ 7.5]]
axis=0,那么輸出矩陣是1行,求每一列的平均(按照每一行去求平均);axis=1,輸出矩陣是1列,求每一行的平均(按照每一列去求平均)。還可以這么理解,axis是幾,那就表明哪一維度被壓縮成1。
實際上這個axis=0就是選擇shape中第一個元素(即第一維)變為1,axis=1就是選擇shape中第二個元素變為1。用shape來看會比較方便。
>>> x = np.array([[1, 2], [4, 5], [7, 8]]) >>> x.shape (3, 2) >>> y = np.mean(x, axis=0, keepdims=True) >>> y.shape (1, 2)
再舉個更復雜點的例子,比如我們輸入為batch = [128, 28, 28],可以理解為batch=128,圖片大小為28×28像素,我們相求這128個圖片的均值,應該這么寫
m = np.mean(batch, axis=0)
輸出結果m的shape為(1,28,28),就是這128個圖片在每一個像素點平均值。
不給出axis不是默認axis為0,而是把所有元素加起來求平均
>>> a = np.array([[1, 2], [3, 4]]) >>> np.mean(a) 2.5