torch.distributions.Categorical()
功能:根據概率分布來產生sample,產生的sample是輸入tensor的index
如:
>>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample() # equal probability of 0, 1, 2, 3
tensor(3)