拒絕for循環,從take_along_axis開始


技術背景

在前一篇文章中,我們提到了關於Numpy中的各種取index的方法,可以用於取出數組里面的元素,也可以用於做切片,甚至可以用來做排序。但是遇到對於高維矩陣的某一個維度取多個值的時候,單純的使用下標已經無法完成相關的操作了。如果找不到相應的接口,對於性能要求不高的場景可以使用一個for循環進行替代,但是對於性能要求比較高的場景下,我們還是盡可能的使用Numpy本身自帶的接口,比如本文將要提到的take_along_axis操作。

使用案例

我們考慮這樣的一個場景,給定一個維度為(4,11,3)的矩陣a作為數據,和一個維度為(4,2)的矩陣b作為下標,意味着從a中第二條軸的11個元素中每次取兩個元素,也就是希望得到一個維度為(4,2,3)的結果:

In [11]: a = np.arange(132).reshape((4,11,3))

In [12]: a
Out[12]: 
array([[[  0,   1,   2],
        [  3,   4,   5],
        [  6,   7,   8],
        [  9,  10,  11],
        [ 12,  13,  14],
        [ 15,  16,  17],
        [ 18,  19,  20],
        [ 21,  22,  23],
        [ 24,  25,  26],
        [ 27,  28,  29],
        [ 30,  31,  32]],

       [[ 33,  34,  35],
        [ 36,  37,  38],
        [ 39,  40,  41],
        [ 42,  43,  44],
        [ 45,  46,  47],
        [ 48,  49,  50],
        [ 51,  52,  53],
        [ 54,  55,  56],
        [ 57,  58,  59],
        [ 60,  61,  62],
        [ 63,  64,  65]],

       [[ 66,  67,  68],
        [ 69,  70,  71],
        [ 72,  73,  74],
        [ 75,  76,  77],
        [ 78,  79,  80],
        [ 81,  82,  83],
        [ 84,  85,  86],
        [ 87,  88,  89],
        [ 90,  91,  92],
        [ 93,  94,  95],
        [ 96,  97,  98]],

       [[ 99, 100, 101],
        [102, 103, 104],
        [105, 106, 107],
        [108, 109, 110],
        [111, 112, 113],
        [114, 115, 116],
        [117, 118, 119],
        [120, 121, 122],
        [123, 124, 125],
        [126, 127, 128],
        [129, 130, 131]]])

In [13]: b = np.array([[0,1],[1,2],[2,3],[3,4]])

In [14]: b
Out[14]: 
array([[0, 1],
       [1, 2],
       [2, 3],
       [3, 4]])

為了方便展示我們就定義了這樣兩個比較簡單的矩陣a和b,那么在這個結果中,我們理想的結果應該是:

[[[  0,   1,   2],
  [  3,   4,   5]],

 [[ 36,  37,  38],
  [ 39,  40,  41]],

 [[ 72,  73,  74],
  [ 75,  76,  77]],

 [[108, 109, 110],
  [111, 112, 113]]]

這樣的一個矩陣。關於這個結果的來源,可以對b這個定義進行展開解釋,b的值為:

[[0, 1],
 [1, 2],
 [2, 3],
 [3, 4]]

它所表示的是在a[0]下取第0個元素和第1個元素,在a[1]下取第1個元素和第2個元素,以此類推。然而如果我們直接把定義好的b放到a的索引中或者直接使用numpy.take的方法的話,得到的結果是這樣的:

In [16]: a[:,b]
Out[16]: 
array([[[[  0,   1,   2],
         [  3,   4,   5]],

        [[  3,   4,   5],
         [  6,   7,   8]],

        [[  6,   7,   8],
         [  9,  10,  11]],

        [[  9,  10,  11],
         [ 12,  13,  14]]],


       [[[ 33,  34,  35],
         [ 36,  37,  38]],

        [[ 36,  37,  38],
         [ 39,  40,  41]],

        [[ 39,  40,  41],
         [ 42,  43,  44]],

        [[ 42,  43,  44],
         [ 45,  46,  47]]],


       [[[ 66,  67,  68],
         [ 69,  70,  71]],

        [[ 69,  70,  71],
         [ 72,  73,  74]],

        [[ 72,  73,  74],
         [ 75,  76,  77]],

        [[ 75,  76,  77],
         [ 78,  79,  80]]],


       [[[ 99, 100, 101],
         [102, 103, 104]],

        [[102, 103, 104],
         [105, 106, 107]],

        [[105, 106, 107],
         [108, 109, 110]],

        [[108, 109, 110],
         [111, 112, 113]]]])

顯然這不是我們想要的結果。需要額外申明的是,這個執行操作中,最后一個維度的冒號加與不加是一樣的效果,跟numpy.take本質上也是同樣的操作,因此就需要使用到numpy中的另外一個接口:take_along_axis,如下是其官方的API文檔:

還有相關的使用案例:

需要注意的是,輸入的indices必須要跟原始的數據矩陣保持同樣的維度,因此在我們自己的案例中,對b進行了擴維,最終的代碼如下所示:

In [23]: np.take_along_axis(a,b[:,:,None],axis=1)
Out[23]: 
array([[[  0,   1,   2],
        [  3,   4,   5]],

       [[ 36,  37,  38],
        [ 39,  40,  41]],

       [[ 72,  73,  74],
        [ 75,  76,  77]],

       [[108, 109, 110],
        [111, 112, 113]]])

最后得到的就是我們想要的結果了,並且是直接使用下標無法實現的操作(當然,也可能是我還沒研究出來這樣的操作)。這里axis設置為1,就表示a的第0個維度和b的第0個維度是一致的取法,也可以理解成全取的意思。

總結概要

Numpy是在Python中用於各種矩陣運算非常強大的工具之一,而快速的通過下標取出所需位置的元素也是numpy所支持的強大功能之一。常規的元素取法都可以通過numpy的下標或者是numpy.take函數來實現,比如array[0,:]可用於取第一條軸的所有元素,array[:,0]可以用於取第二條軸的所有第二個元素,放在一個2維的矩陣里面就分別是取第一行的所有元素和取第一列的所有元素。但是本文更加關注於更高維的矩陣,當我們想從多個維度中取多個元素時,是不太容易直接用下標去取的,比如同時取a[0][0],a[0][1],a[1][1],a[1][2]的話,那么就只能使用numpy所支持的另外一個函數numpy.take_along_axis來實現。

版權聲明

本文首發鏈接為:https://www.cnblogs.com/dechinphy/p/take_along_axis.html

作者ID:DechinPhy

更多原著文章請參考:https://www.cnblogs.com/dechinphy/

打賞專用鏈接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

騰訊雲專欄同步:https://cloud.tencent.com/developer/column/91958

參考鏈接

  1. https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html#numpy.take_along_axis


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM