pytorch中參數dim的含義(正負,零,不傳)


總結:

torch.function(x, dim)

1.if 不傳: 依照默認參數決定

2.if dim >=0 and dim <= x.dim()-1: 0是沿最粗數據粒度的方向進行操作,x.dim()-1是按最細粒度的方向。

3.if dim <0: dim的最小取值(此按照不同function而定)到最大取值(-1)之間。與情況2正好相反,最大的取值(-1)代表按最細粒度的方向,最小的取值按最粗粒度的方向。

 

實驗代碼:(使用torch.max(x, dim)為例子)

1.dim=2

m
Out[77]:
tensor([[1, 2, 3],
        [4, 5, 6]])

 

torch.max(m,)
Out[85]: tensor(6)

不傳:默認參數的設定是對整個傳入的數據進行操作


torch.max(m, dim=0)
Out[79]:
torch.return_types.max(
values=tensor([4, 5, 6]),
indices=tensor([1, 1, 1]))

此處最粗粒度是兩行之間[1, 2, 3]->[4, 5, 6]的方向,也就是常說是縱向進行操作。


torch.max(m, dim=1)
Out[78]:
torch.return_types.max(
values=tensor([3, 6]),
indices=tensor([2, 2]))

此處最細粒度是一行之內[1, 2, 3]的方向,也就是常說是橫向進行操作。

 

torch.max(m, dim=2)
Traceback (most recent call last):
  File "/home/xutianfan/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-84-ce6440fe62e4>", line 1, in <module>
    torch.max(m, dim=2)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

 

torch.max(m, dim=-1)
Out[86]:
torch.return_types.max(
values=tensor([3, 6]),
indices=tensor([2, 2]))

-1+2=1,同torch.max(m, dim=1)結果。

 

torch.max(m, dim=-2)
Out[87]:
torch.return_types.max(
values=tensor([4, 5, 6]),
indices=tensor([1, 1, 1]))

 

 

2.dim=3(tensor)

t1
Out[89]:
tensor([[[0, 1, 2, 3],
         [1, 2, 3, 4]],
        [[2, 3, 4, 5],
         [4, 5, 6, 7]],
        [[5, 6, 7, 8],
         [6, 7, 8, 9]]])

torch.max(t1)
Out[94]: tensor(9)

 

torch.max(t1, dim=0)
Out[91]:
torch.return_types.max(
values=tensor([[5, 6, 7, 8],
        [6, 7, 8, 9]]),
indices=tensor([[2, 2, 2, 2],
        [2, 2, 2, 2]]))

最粗粒度是在各個矩陣之間的方向,所以對各個矩陣的每個位置分別取最大。

 

torch.max(t1, dim=1)
Out[92]:
torch.return_types.max(
values=tensor([[1, 2, 3, 4],
        [4, 5, 6, 7],
        [6, 7, 8, 9]]),
indices=tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]]))

其次粗的粒度是矩陣中各行之間的方向

 

torch.max(t1, dim=2)
Out[93]:
torch.return_types.max(
values=tensor([[3, 4],
        [5, 7],
        [8, 9]]),
indices=tensor([[3, 3],
        [3, 3],
        [3, 3]]))
最細粒度是各行之內的方向。所以取出了各行中最大的元素。

 

torch.max(t1, dim=-1)
Out[97]:
torch.return_types.max(
values=tensor([[3, 4],
        [5, 7],
        [8, 9]]),
indices=tensor([[3, 3],
        [3, 3],
        [3, 3]]))

 

 

雖然我們這里只使用了max函數,但是這對於torch中其他函數(例如softmax)也有效。

可以有這種寫法:mean = x.mean(-1, keepdim=True)

這樣無論是對於2維還是3維的輸入,都自動dim=input.dim()-1,也就是從最細粒度取平均。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM