技術背景
在前一篇文章中,我們提到了關於Numpy中的各種取index的方法,可以用於取出數組里面的元素,也可以用於做切片,甚至可以用來做排序。但是遇到對於高維矩陣的某一個維度取多個值的時候,單純的使用下標已經無法完成相關的操作了。如果找不到相應的接口,對於性能要求不高的場景可以使用一個for循環進行替代,但是對於性能要求比較高的場景下,我們還是盡可能的使用Numpy本身自帶的接口,比如本文將要提到的take_along_axis操作。
使用案例
我們考慮這樣的一個場景,給定一個維度為(4,11,3)的矩陣a作為數據,和一個維度為(4,2)的矩陣b作為下標,意味着從a中第二條軸的11個元素中每次取兩個元素,也就是希望得到一個維度為(4,2,3)的結果:
In [11]: a = np.arange(132).reshape((4,11,3))
In [12]: a
Out[12]:
array([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[ 12, 13, 14],
[ 15, 16, 17],
[ 18, 19, 20],
[ 21, 22, 23],
[ 24, 25, 26],
[ 27, 28, 29],
[ 30, 31, 32]],
[[ 33, 34, 35],
[ 36, 37, 38],
[ 39, 40, 41],
[ 42, 43, 44],
[ 45, 46, 47],
[ 48, 49, 50],
[ 51, 52, 53],
[ 54, 55, 56],
[ 57, 58, 59],
[ 60, 61, 62],
[ 63, 64, 65]],
[[ 66, 67, 68],
[ 69, 70, 71],
[ 72, 73, 74],
[ 75, 76, 77],
[ 78, 79, 80],
[ 81, 82, 83],
[ 84, 85, 86],
[ 87, 88, 89],
[ 90, 91, 92],
[ 93, 94, 95],
[ 96, 97, 98]],
[[ 99, 100, 101],
[102, 103, 104],
[105, 106, 107],
[108, 109, 110],
[111, 112, 113],
[114, 115, 116],
[117, 118, 119],
[120, 121, 122],
[123, 124, 125],
[126, 127, 128],
[129, 130, 131]]])
In [13]: b = np.array([[0,1],[1,2],[2,3],[3,4]])
In [14]: b
Out[14]:
array([[0, 1],
[1, 2],
[2, 3],
[3, 4]])
為了方便展示我們就定義了這樣兩個比較簡單的矩陣a和b,那么在這個結果中,我們理想的結果應該是:
[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 36, 37, 38],
[ 39, 40, 41]],
[[ 72, 73, 74],
[ 75, 76, 77]],
[[108, 109, 110],
[111, 112, 113]]]
這樣的一個矩陣。關於這個結果的來源,可以對b這個定義進行展開解釋,b的值為:
[[0, 1],
[1, 2],
[2, 3],
[3, 4]]
它所表示的是在a[0]下取第0個元素和第1個元素,在a[1]下取第1個元素和第2個元素,以此類推。然而如果我們直接把定義好的b放到a的索引中或者直接使用numpy.take的方法的話,得到的結果是這樣的:
In [16]: a[:,b]
Out[16]:
array([[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 3, 4, 5],
[ 6, 7, 8]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[ 9, 10, 11],
[ 12, 13, 14]]],
[[[ 33, 34, 35],
[ 36, 37, 38]],
[[ 36, 37, 38],
[ 39, 40, 41]],
[[ 39, 40, 41],
[ 42, 43, 44]],
[[ 42, 43, 44],
[ 45, 46, 47]]],
[[[ 66, 67, 68],
[ 69, 70, 71]],
[[ 69, 70, 71],
[ 72, 73, 74]],
[[ 72, 73, 74],
[ 75, 76, 77]],
[[ 75, 76, 77],
[ 78, 79, 80]]],
[[[ 99, 100, 101],
[102, 103, 104]],
[[102, 103, 104],
[105, 106, 107]],
[[105, 106, 107],
[108, 109, 110]],
[[108, 109, 110],
[111, 112, 113]]]])
顯然這不是我們想要的結果。需要額外申明的是,這個執行操作中,最后一個維度的冒號加與不加是一樣的效果,跟numpy.take本質上也是同樣的操作,因此就需要使用到numpy中的另外一個接口:take_along_axis
,如下是其官方的API文檔:
還有相關的使用案例:
需要注意的是,輸入的indices必須要跟原始的數據矩陣保持同樣的維度,因此在我們自己的案例中,對b進行了擴維,最終的代碼如下所示:
In [23]: np.take_along_axis(a,b[:,:,None],axis=1)
Out[23]:
array([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 36, 37, 38],
[ 39, 40, 41]],
[[ 72, 73, 74],
[ 75, 76, 77]],
[[108, 109, 110],
[111, 112, 113]]])
最后得到的就是我們想要的結果了,並且是直接使用下標無法實現的操作(當然,也可能是我還沒研究出來這樣的操作)。這里axis設置為1,就表示a的第0個維度和b的第0個維度是一致的取法,也可以理解成全取的意思。
總結概要
Numpy是在Python中用於各種矩陣運算非常強大的工具之一,而快速的通過下標取出所需位置的元素也是numpy所支持的強大功能之一。常規的元素取法都可以通過numpy的下標或者是numpy.take函數來實現,比如array[0,:]可用於取第一條軸的所有元素,array[:,0]可以用於取第二條軸的所有第二個元素,放在一個2維的矩陣里面就分別是取第一行的所有元素和取第一列的所有元素。但是本文更加關注於更高維的矩陣,當我們想從多個維度中取多個元素時,是不太容易直接用下標去取的,比如同時取a[0][0],a[0][1],a[1][1],a[1][2]的話,那么就只能使用numpy所支持的另外一個函數numpy.take_along_axis來實現。
版權聲明
本文首發鏈接為:https://www.cnblogs.com/dechinphy/p/take_along_axis.html
作者ID:DechinPhy
更多原著文章請參考:https://www.cnblogs.com/dechinphy/
打賞專用鏈接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
騰訊雲專欄同步:https://cloud.tencent.com/developer/column/91958