關於numpy mean函數的axis參數


 

import numpy as np
X = np.array([[1, 2], [4, 5], [7, 8]])
print np.mean(X, axis=0, keepdims=True)
print np.mean(X, axis=1, keepdims=True)

結果是分別是

            [[ 1.5]
 [[ 4.  5.]]      [ 4.5]    
                  [ 7.5]]

axis=0,那么輸出矩陣是1行,求每一列的平均(按照每一行去求平均);axis=1,輸出矩陣是1列,求每一行的平均(按照每一列去求平均)。還可以這么理解,axis是幾,那就表明哪一維度被壓縮成1。

實際上這個axis=0就是選擇shape中第一個元素(即第一維)變為1,axis=1就是選擇shape中第二個元素變為1。用shape來看會比較方便。

>>> x = np.array([[1, 2], [4, 5], [7, 8]])
>>> x.shape
(3, 2)
>>> y = np.mean(x, axis=0, keepdims=True)
>>> y.shape
(1, 2)

再舉個更復雜點的例子,比如我們輸入為batch = [128, 28, 28],可以理解為batch=128,圖片大小為28×28像素,我們相求這128個圖片的均值,應該這么寫

m = np.mean(batch, axis=0)

輸出結果m的shape為(1,28,28),就是這128個圖片在每一個像素點平均值。

 

 

不給出axis不是默認axis為0,而是把所有元素加起來求平均

    >>> a = np.array([[1, 2], [3, 4]])
    >>> np.mean(a)
    2.5

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM