對於numpy中的函數的參數dim的一點理解
經常被dim參數搞混。試着總結了一下。記憶瞬間清晰了
以.max(dim)方法為例:
>>> import numpy as np
>>> a = np.random.randint(1, 100, [2, 3, 4])
>>> a
array([[[26, 36, 31, 21],
[74, 59, 79, 32],
[77, 94, 81, 32]],
[[72, 76, 85, 93],
[66, 34, 80, 12],
[99, 17, 98, 23]]])
>>> for i in range(3):
... print(a.max(i))
...
[[72 76 85 93]
[74 59 80 32]
[99 94 98 32]]
[[77 94 81 32]
[99 76 98 93]]
[[36 79 94]
[93 80 99]]
可以見得:
a是一個2x3x4的三維矩陣。
當a.max(0)時,max則在維度大小為2的方向上進行操作,所以
a.max(0)就是:
[[72 76 85 93]
[74 59 80 32]
[99 94 98 32]]
一個 1x3x4的矩陣。
以此類推,a.max(1)就是在維度大小為3的方向上進行操作
a.max(i)就是:
[[77 94 81 32]
[99 76 98 93]]
一個 1x2x4的矩陣。
由此很容易發現。
.max(dim)中的dim,並不是a上的維度。而是指a的shape上的順序(可以這么理解),a的shape是2x3x4,也就是[2, 3, 4]。故可以這樣一一對應以來。
而不用死記硬背那些0是對列操作還是對行操作了