理解 numpy.rollaxis() 函数


函数声明

先看看 numpy.rollaxis() 函数的定义形式,如下:

rollaxis(a, axis, start=0)

参数 a 通常为 numpy.ndarray 类型,则 a.ndim表示 numpy 数组的维数;

参数 axis 通常为 int 类型,范围为 [0, a.ndim);

参数 start 为 int 类型,默认值为 0,取值范围为 [-a.ndim, a.ndim],如果超过这个范围,则会 raise AxisError。

函数功能

numpy.rollaxis() 函数用于滚动(roll)指定轴(axis)到某位置。这个函数可以用更易理解的函数 numpy.moveaxis(a, source, destination) 代替。但由于 numpy.moveaxis() 函数是在 numpy v1.11 版本新增的,为了与之前的版本兼容,这个函数依旧保留。

具体来说,需要根据 axis 和 normalized start 的比较结果,选择将 axis 滚动到哪个位置上,而其他轴的位置顺序不变。如果 axis 参数值大于或等于 normalized start,则 axis 从后向前滚动,直到 start 位置;如果 axis 参数值小于 normalized start,则 axis 轴从前往后滚动,直到 start 的前一个位置,即 start-1 位置。其中 start 和 normalized start 的对应关系,如下表所示:

start

Normalized start

-(a.ndim+1)

raise AxisError

-a.ndim

0

-1

a.ndim-1

0

0

a.ndim

a.ndim

a.ndim+1

raise AxisError

从表中,可以看出 normalized start 是在 -a.ndim <= start < 0 时, start + a.ndim 的值;在  0 <= start <= a.ndim 时,start 值。

具体的示例及解释,如下所示

import numpy as np a = np.ones((3,4,5,6)) axis, start = 3, 1
# 因为 3 > 1,所以 axis index 3 移动到 axis index 1(start位置),而其他维度位置不变
print(np.rollaxis(a, axis=axis, start=start).shape)  # (3,6,4,5) # np.moveaxis 的等价调用
print(np.moveaxis(a, source=axis, destination=start).shape) axis, start = 2, 0 # 因为 2 > 0,所以 axis index 2 移动到 axis index 0(start位置),而其他维度位置不变
print(np.rollaxis(a, axis, start).shape)  # (5,3,4,6) # np.moveaxis 的等价调用
print(np.moveaxis(a, axis, start).shape) axis, start = 1, 4
# 因为 1 < 4,所以 axis index 1 移动到 axis index 3(start-1位置),而其他维度位置不变
print(np.rollaxis(a, axis=axis, start=start).shape)  # (3,5,6,4) # np.moveaxis 的等价调用
print(np.moveaxis(a, source=axis, destination=start-1).shape)

 为了更好理解这个过程,最后看看该函数在 numpy 中实现的核心代码,如下所示:

def rollaxis(a, axis, start=0): """ Roll the specified axis backwards, until it lies in a given position. Parameters ---------- a : ndarray Input array. axis : int The axis to be rolled. The positions of the other axes do not change relative to one another. start : int, optional When ``start <= axis``, the axis is rolled back until it lies in this position. When ``start > axis``, the axis is rolled until it lies before this position. The default, 0, results in a "complete" roll. Returns ------- res : ndarray For NumPy >= 1.10.0 a view of `a` is always returned. For earlier NumPy versions a view of `a` is returned only if the order of the axes is changed, otherwise the input array is returned. """ n = a.ndim if start < 0: start += n msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
    if not (0 <= start < n + 1): raise AxisError(msg % ('start', -n, 'start', n + 1, start)) if axis < start:
        start -= 1
    if axis == start: return a[...] axes = list(range(0, n)) axes.remove(axis) axes.insert(start, axis) return a.transpose(axes)

参考资料

[1] numpy.rollaxis API reference. https://numpy.org/doc/stable/reference/generated/numpy.rollaxis.html


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM