Pytorch 中對 tensor 的很多操作如 sum
、argmax
、等都可以設置 dim
參數用來指定操作在哪一維進行。Pytorch 中的 dim 類似於 numpy 中的 axis,這篇文章來總結一下 Pytorch 中的 dim 操作。
dim 與方括號的關系
創建一個矩陣
a = torch.tensor([[1, 2], [3, 4]])
print(a)
輸出
tensor([[1, 2],
[3, 4]])
因為a
是一個矩陣,所以a
的左邊有 2 個括號
括號之間是嵌套關系,代表了不同的維度。從左往右數,兩個括號代表的維度分別是 0 和 1 ,在第 0 維遍歷得到向量,在第 1 維遍歷得到標量
同樣地,對於 3 維 tensor
b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)
輸出
tensor([[[3, 2],
[1, 4]],
[[5, 6],
[7, 8]]])
則 3 個括號代表的維度從左往右分別為 0, 1, 2,在第 0 維遍歷得到矩陣,在第 1 維遍歷得到向量,在第 2 維遍歷得到標量
更詳細一點
在指定的維度上進行操作
在某一維度求和(或者進行其他操作)就是對該維度中的元素進行求和。
對於矩陣 a
a = torch.tensor([[1, 2], [3, 4]])
print(a)
輸出
tensor([[1, 2],
[3, 4]])
求 a 在第 0 維的和,因為第 0 維代表最外邊的括號,括號中的元素為向量[1, 2]
,[3, 4]
,第 0 維的和就是第 0 維中的元素相加,也就是兩個向量[1, 2]
,[3, 4]
相加,所以結果為
s = torch.sum(a, dim=0)
print(s)
輸出
tensor([4, 6])
可以看到,a 是 2 維矩陣,而相加的結果為 1 維向量,可以使用參數keepdim=True
來保證形狀不變
s = torch.sum(a, dim=0, keepdim=True)
print(s)
輸出
tensor([[4, 6]])
在 a 的第 0 維求和,就是對第 0 維中的元素(向量)進行相加。同樣的,對 a 第 1 維求和,就是對 a 第 1 維中的元素(標量)進行相加,a 的第 1 維元素為標量 1,2 和 3,4,則結果為
s = torch.sum(a, dim=1)
print(s)
輸出
tensor([3, 7])
保持維度不變
s = torch.sum(a, dim=1, keepdim=True)
print(s)
輸出
tensor([[3],
[7]])
對 3 維 tensor 的操作也是這樣
b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)
輸出
tensor([[[3, 2],
[1, 4]],
[[5, 6],
[7, 8]]])
將 b 在第 0 維相加,第 0 維為最外層括號,最外層括號中的元素為矩陣[[3, 2], [1, 4]]
和[[5, 6], [7, 8]]
。在第 0 維求和,就是將第 0 維中的元素(矩陣)相加
s = torch.sum(b, dim=0)
print(s)
輸出
tensor([[ 8, 8],
[ 8, 12]])
求 b 在第 1 維的和,就是將 b 第 1 維中的元素[3, 2]
和[1, 4]
, [5, 6]
和 [7, 8]
相加,所以
s = torch.sum(b, dim=1)
print(s)
輸出
tensor([[ 4, 6],
[12, 14]])
則在 b 的第 2 維求和,就是對標量 3 和 2, 1 和 4, 5 和 6 , 7 和 8 求和
s = torch.sum(b, dim=2)
print(s)
結果為
tensor([[ 5, 5],
[11, 15]])
除了求和,其他操作也是類似的,如求 b 在指定維度上的最大值
m = torch.max(b, dim=0)
print(m)
b 在第 0 維的最大值是第 0 維中的元素(兩個矩陣[[3, 2], [1, 4]]
和[[5, 6], [7, 8]]
)的最大值,取矩陣對應位置最大值即可
結果為
torch.return_types.max(
values=tensor([[5, 6],
[7, 8]]),
indices=tensor([[1, 1],
[1, 1]]))
b 在第 1 維的最大值就是第 1 維元素(4 個(2對)向量)的最大值
m = torch.max(b, dim=1)
print(m)
輸出為
torch.return_types.max(
values=tensor([[3, 4],
[7, 8]]),
indices=tensor([[0, 1],
[1, 1]]))
b 在第 0 維的最大值就是第 0 為元素(8 個(4 對)標量)的最大值
m = torch.max(b, dim=2)
print(m)
輸出
torch.return_types.max(
values=tensor([[3, 4],
[6, 8]]),
indices=tensor([[0, 1],
[1, 1]]))
總結
在 tensor 的指定維度操作就是對指定維度包含的元素進行操作,如果想要保持結果的維度不變,設置參數keepdim=True
即可。