numpy 中的 broadcasting 理解


broadcast 是 numpy 中 array 的一個重要操作。

 

首先,broadcast 只適用於加減。

 

然后,broadcast 執行的時候,如果兩個 array 的 shape 不一樣,會先給“短”的那一個,增加高維度“擴展”(broadcasting),比如,一個 2 維的 array,可以是一個 3 維 size 為 1 的 3維 array。

類似於: shape(1,3,2) = shape(3,2)

 

最后,比較兩個 array(擴展后的),按照 dimension 從低到高,比較每一個維度的 size 是否滿足下面兩個條件之一:

1. 相等

2. 其中一個為 1

 

所以,舉例,下列 array 是否可以進行 broadcast:

1. shape(4, 3) 與 shape(3,) :shape(3) 可以 broadcast 為 shape(1, 3),那么,從低到高: d0(3 === 3), d1(其中一個為 1)。結論,可以,結果的為 shape(4, 3)

2. shape(6,5,4,3, 與 shape(5, 4, 3):shape(5, 4, 3) 可以 broadcast 為 shape(1,5,4,3),那么,從低到高:d0( 3 === 3), d1(4 === 4), d2(5===5),d3(其中一個為 1)。結論,可以,結果為 shape(6, 5, 4, 3)。

3. shape(2,3) 與 shape(5,4,3):shape(2,3) 可以 broadcast 為 shape(1, 2, 3),那么,從低到高:d0( 3 == 3), d1(4!=2)。結論,不能進行 broadcast。

4. shape(4,1) 與 shape(5):shape(5)可以 broadcast 為 shape(1,5),那么,從低到高: d0( 其中一個為 1), d1(其中一個為 1)。結論,可以進行 broadcast,結果為 shape(4, 5) 。

broadcast 之后的運算是怎樣呢?舉例說明:

a = [ [0,1,2,3], [4,5,6,7] ]

b = [1,2,3,4]

a + b = [ [1,3,5,7], [5,7,9,11] ]

 或可自己運行下面代碼觀察

import numpy as np

a = np.arange(12)
b = a.reshape(3,2,2)

c = np.arange(4)
d = c.reshape(2, 2)

e = np.arange(2)

print d+b

print e+b

還有下面一種特殊情況,即擴展低維度為 1 的情況下:

import numpy as np 

a = np.arange(3)

b = np.arange(5)

a = a[:, np.newaxis]

print a
print b

print a+b

 

基本上是只在對應的 dimension 進行加減,擴展的部分不參與運算。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM