Pytorch中torch.nn.Softmax的dim參數含義


import torch.nn as nn

m = nn.Softmax(dim=0)

input = torch.randn(2, 2, 3)
print(input)
print(m(input))

 input:

tensor([[[ 0.5450, -0.6264,  1.0446],
         [ 0.6324,  1.9069,  0.7158]],

        [[ 1.0092,  0.2421, -0.8928],
         [ 0.0344,  0.9723,  0.4328]]])

 dim=0:

tensor([[[0.3860, 0.2956, 0.8741],
         [0.6452, 0.7180, 0.5703]],

        [[0.6140, 0.7044, 0.1259],
         [0.3548, 0.2820, 0.4297]]])

dim=0時,在第0維sum=1,即:

[[0.3860, 0.2956, 0.8741],
         [0.6452, 0.7180, 0.5703]] 和
[[0.6140, 0.7044, 0.1259],
         [0.3548, 0.2820, 0.4297]]對應位置和為1


[0][0][0]+[1][0][0]=0.3860+0.6140=1
[0][0][1]+[1][0][1]=0.2956+0.7044=1
… …  

0.7044 第一層是1,第二層是0,第三層是1

dim=1時,在第1維上sum=1,即:

[0.3860, 0.2956, 0.8741]和
[0.6452, 0.7180, 0.5703]

  


[0][0][0]+[0][1][0]=0.3860+0.6452=1

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM