Pytorch中torch.nn.Softmax的dim参数含义


import torch.nn as nn

m = nn.Softmax(dim=0)

input = torch.randn(2, 2, 3)
print(input)
print(m(input))

 input:

tensor([[[ 0.5450, -0.6264,  1.0446],
         [ 0.6324,  1.9069,  0.7158]],

        [[ 1.0092,  0.2421, -0.8928],
         [ 0.0344,  0.9723,  0.4328]]])

 dim=0:

tensor([[[0.3860, 0.2956, 0.8741],
         [0.6452, 0.7180, 0.5703]],

        [[0.6140, 0.7044, 0.1259],
         [0.3548, 0.2820, 0.4297]]])

dim=0时,在第0维sum=1,即:

[[0.3860, 0.2956, 0.8741],
         [0.6452, 0.7180, 0.5703]] 和
[[0.6140, 0.7044, 0.1259],
         [0.3548, 0.2820, 0.4297]]对应位置和为1


[0][0][0]+[1][0][0]=0.3860+0.6140=1
[0][0][1]+[1][0][1]=0.2956+0.7044=1
… …  

0.7044 第一层是1,第二层是0,第三层是1

dim=1时,在第1维上sum=1,即:

[0.3860, 0.2956, 0.8741]和
[0.6452, 0.7180, 0.5703]

  


[0][0][0]+[0][1][0]=0.3860+0.6452=1

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM