理解 numpy.rollaxis() 函數


函數聲明

先看看 numpy.rollaxis() 函數的定義形式,如下:

rollaxis(a, axis, start=0)

參數 a 通常為 numpy.ndarray 類型,則 a.ndim表示 numpy 數組的維數;

參數 axis 通常為 int 類型,范圍為 [0, a.ndim);

參數 start 為 int 類型,默認值為 0,取值范圍為 [-a.ndim, a.ndim],如果超過這個范圍,則會 raise AxisError。

函數功能

numpy.rollaxis() 函數用於滾動(roll)指定軸(axis)到某位置。這個函數可以用更易理解的函數 numpy.moveaxis(a, source, destination) 代替。但由於 numpy.moveaxis() 函數是在 numpy v1.11 版本新增的,為了與之前的版本兼容,這個函數依舊保留。

具體來說,需要根據 axis 和 normalized start 的比較結果,選擇將 axis 滾動到哪個位置上,而其他軸的位置順序不變。如果 axis 參數值大於或等於 normalized start,則 axis 從后向前滾動,直到 start 位置;如果 axis 參數值小於 normalized start,則 axis 軸從前往后滾動,直到 start 的前一個位置,即 start-1 位置。其中 start 和 normalized start 的對應關系,如下表所示:

start

Normalized start

-(a.ndim+1)

raise AxisError

-a.ndim

0

-1

a.ndim-1

0

0

a.ndim

a.ndim

a.ndim+1

raise AxisError

從表中,可以看出 normalized start 是在 -a.ndim <= start < 0 時, start + a.ndim 的值;在  0 <= start <= a.ndim 時,start 值。

具體的示例及解釋,如下所示

import numpy as np a = np.ones((3,4,5,6)) axis, start = 3, 1
# 因為 3 > 1,所以 axis index 3 移動到 axis index 1(start位置),而其他維度位置不變
print(np.rollaxis(a, axis=axis, start=start).shape)  # (3,6,4,5) # np.moveaxis 的等價調用
print(np.moveaxis(a, source=axis, destination=start).shape) axis, start = 2, 0 # 因為 2 > 0,所以 axis index 2 移動到 axis index 0(start位置),而其他維度位置不變
print(np.rollaxis(a, axis, start).shape)  # (5,3,4,6) # np.moveaxis 的等價調用
print(np.moveaxis(a, axis, start).shape) axis, start = 1, 4
# 因為 1 < 4,所以 axis index 1 移動到 axis index 3(start-1位置),而其他維度位置不變
print(np.rollaxis(a, axis=axis, start=start).shape)  # (3,5,6,4) # np.moveaxis 的等價調用
print(np.moveaxis(a, source=axis, destination=start-1).shape)

 為了更好理解這個過程,最后看看該函數在 numpy 中實現的核心代碼,如下所示:

def rollaxis(a, axis, start=0): """ Roll the specified axis backwards, until it lies in a given position. Parameters ---------- a : ndarray Input array. axis : int The axis to be rolled. The positions of the other axes do not change relative to one another. start : int, optional When ``start <= axis``, the axis is rolled back until it lies in this position. When ``start > axis``, the axis is rolled until it lies before this position. The default, 0, results in a "complete" roll. Returns ------- res : ndarray For NumPy >= 1.10.0 a view of `a` is always returned. For earlier NumPy versions a view of `a` is returned only if the order of the axes is changed, otherwise the input array is returned. """ n = a.ndim if start < 0: start += n msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
    if not (0 <= start < n + 1): raise AxisError(msg % ('start', -n, 'start', n + 1, start)) if axis < start:
        start -= 1
    if axis == start: return a[...] axes = list(range(0, n)) axes.remove(axis) axes.insert(start, axis) return a.transpose(axes)

參考資料

[1] numpy.rollaxis API reference. https://numpy.org/doc/stable/reference/generated/numpy.rollaxis.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM