torch.argmax()函數
argmax函數:torch.argmax(input, dim=None, keepdim=False)
返回指定維度最大值的序號,dim給定的定義是:the demention to reduce.也就是把dim這個維度的,變成這個維度的最大值的index。
例如tensor(2, 3, 4)
dim=0,將第一維度去掉,這樣結果為tensor(3, 4)
。
import torch
a=torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
b=torch.argmax(a,dim=0)
print(b)
print(a.shape)
"""
tensor([[0, 1, 0, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]])
torch.Size([2, 3, 4])
"""
# 去掉第一維度,這樣剩下兩個3x4數組,將兩個數組的對應位置進行比較,例如:a[0][0][0]和a[1][0][0]比較
#因為a[1][0][0]大,所以b[0][0][0]就是1,以此類推 這里的0和1表示對應的位置,第0個數組還是第1數組大
dim=1, 將第二維度去掉,取每一列的最大值。結果展示為tensor(2, 4)
import torch
a=torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
b=torch.argmax(a,dim=1)
print(b)
print(a.shape)
"""
tensor([[1, 2, 0, 1],
[1, 2, 2, 1]])
torch.Size([2, 3, 4])
"""
# 去掉第二維度,結果為是一個2x4,將每一個3x4數組,變成1x4數組,經過變化后a[0] = tensor([9, 7, 5, 8])
#取每一列的最大值,a[0]中第一列的最大值的行標為1, 第二列的最大值的行標為2,第三列的最大值行標為0,第4列的最大值行標為1
#所以最后輸出[1, 2, 0, 1]
#以此類推,
dim=2, 將第三維度去掉,取每一行的最大值。結果展示為tensor(2, 3)
import torch
a=torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
b=torch.argmax(a,dim=2)
print(b)
print(a.shape)
"""
tensor([[2, 0, 1],
[1, 0, 2]])
torch.Size([2, 3, 4])
"""
# 去掉第三維度,結果為是一個2x3,將每一個3x4數組,變成3x1數組,就好像經過變化后a[0] = tensor([5, 9, 7]的轉置)
#取每一行的最大值,a[0]中第一行的最大值的列標為2, 第二行的最大值的列標為0,第三行的最大值列標為1,
#所以最后輸出[2, 0, 1]
#以此類推,