函數聲明
先看看 numpy.rollaxis() 函數的定義形式,如下:
rollaxis(a, axis, start=0)
參數 a 通常為 numpy.ndarray 類型,則 a.ndim表示 numpy 數組的維數;
參數 axis 通常為 int 類型,范圍為 [0, a.ndim);
參數 start 為 int 類型,默認值為 0,取值范圍為 [-a.ndim, a.ndim],如果超過這個范圍,則會 raise AxisError。
函數功能
numpy.rollaxis() 函數用於滾動(roll)指定軸(axis)到某位置。這個函數可以用更易理解的函數 numpy.moveaxis(a, source, destination) 代替。但由於 numpy.moveaxis() 函數是在 numpy v1.11 版本新增的,為了與之前的版本兼容,這個函數依舊保留。
具體來說,需要根據 axis 和 normalized start 的比較結果,選擇將 axis 滾動到哪個位置上,而其他軸的位置順序不變。如果 axis 參數值大於或等於 normalized start,則 axis 從后向前滾動,直到 start 位置;如果 axis 參數值小於 normalized start,則 axis 軸從前往后滾動,直到 start 的前一個位置,即 start-1 位置。其中 start 和 normalized start 的對應關系,如下表所示:
start |
Normalized start |
-(a.ndim+1) |
raise AxisError |
-a.ndim |
0 |
⋮ |
⋮ |
-1 |
a.ndim-1 |
0 |
0 |
⋮ |
⋮ |
a.ndim |
a.ndim |
a.ndim+1 |
raise AxisError |
從表中,可以看出 normalized start 是在 -a.ndim <= start < 0 時, start + a.ndim 的值;在 0 <= start <= a.ndim 時,start 值。
具體的示例及解釋,如下所示
import numpy as np a = np.ones((3,4,5,6)) axis, start = 3, 1
# 因為 3 > 1,所以 axis index 3 移動到 axis index 1(start位置),而其他維度位置不變
print(np.rollaxis(a, axis=axis, start=start).shape) # (3,6,4,5) # np.moveaxis 的等價調用
print(np.moveaxis(a, source=axis, destination=start).shape) axis, start = 2, 0 # 因為 2 > 0,所以 axis index 2 移動到 axis index 0(start位置),而其他維度位置不變
print(np.rollaxis(a, axis, start).shape) # (5,3,4,6) # np.moveaxis 的等價調用
print(np.moveaxis(a, axis, start).shape) axis, start = 1, 4
# 因為 1 < 4,所以 axis index 1 移動到 axis index 3(start-1位置),而其他維度位置不變
print(np.rollaxis(a, axis=axis, start=start).shape) # (3,5,6,4) # np.moveaxis 的等價調用
print(np.moveaxis(a, source=axis, destination=start-1).shape)
為了更好理解這個過程,最后看看該函數在 numpy 中實現的核心代碼,如下所示:
def rollaxis(a, axis, start=0): """ Roll the specified axis backwards, until it lies in a given position. Parameters ---------- a : ndarray Input array. axis : int The axis to be rolled. The positions of the other axes do not change relative to one another. start : int, optional When ``start <= axis``, the axis is rolled back until it lies in this position. When ``start > axis``, the axis is rolled until it lies before this position. The default, 0, results in a "complete" roll. Returns ------- res : ndarray For NumPy >= 1.10.0 a view of `a` is always returned. For earlier NumPy versions a view of `a` is returned only if the order of the axes is changed, otherwise the input array is returned. """ n = a.ndim if start < 0: start += n msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
if not (0 <= start < n + 1): raise AxisError(msg % ('start', -n, 'start', n + 1, start)) if axis < start:
start -= 1
if axis == start: return a[...] axes = list(range(0, n)) axes.remove(axis) axes.insert(start, axis) return a.transpose(axes)
參考資料
[1] numpy.rollaxis API reference. https://numpy.org/doc/stable/reference/generated/numpy.rollaxis.html.