Normal(means, sigma)的目的在于设置一个高斯分布
其中means的形状和sigma的形状可以不一致,遵循广播原理
from torch.distributions import Normal mu = torch.tensor([1, 10], dtype=torch.float32) sigma = torch.tensor([1], dtype=torch.float32) dist = Normal(mu, sigma) # 设置高斯分布的均值和方差 dist.sample() # 采样 >>> tensor([-0.9153, 8.3727])
设置的高斯分布中sigma虽然只传入了1,这里应该是广播机制,会生成一个二维高斯分布,[N(1,1), N(10, 1)]
对其进行采样dist.sample(),会得到一个数组
>>> tensor([-0.9153, 8.3727])
log_prob(x)用来计算输入数据x在分布中的对于概率密度的对数
x = torch.tensor([1, 10, 10, 1], dtype=torch.float32).reshape(-1, 2) dist.log_prob(x) >>> tensor([[ -0.9189, -0.9189], [-41.4189, -41.4189]])
其中,x = [[1, 10], [10, 1]],这一个x包括两个数组[1, 10]和[10, 1]
在log_prob(x)的计算中,分别算的是这两个数组在分布[N(1,1), N(10, 1)]上的概率密度,即:
[1, 10]的1在分布N(1, 1) 上的概率密度对数 -0.9189
[1, 10]的10在分布N(10, 1)上的概率密度对数 -0.9189
为输出结果
>>> tensor([[ -0.9189, -0.9189],
[-41.4189, -41.4189]])
中的第一行数据
[10, 1]的10在分布N(1, 1) 上的概率密度对数 -41.4189
[10, 1]的1在分布N(10, 1) 上的概率密度对数 -41.4189
为输出结果的第二行。
上述说明,这个分布的计算其实就是简单的多个1维不相关的高斯分布分别计算对应位置的数据的 对数 概率密度。
每一个数据的计算公式为:
如果直接想知道概率密度是多少可以在后面加上 .exp() 还原
dist.log_prob(x).exp() >>> tensor([[3.9894e-01, 3.9894e-01], [1.0280e-18, 1.0280e-18]])
如图,这是N(1, 1)的概率密度函数在x=1时的取值,从图上大概看出在0.4左右,与dist.log_prob(x).exp()的结果一致
再举一个例子:
1 mu = torch.tensor([1, 10], dtype=torch.float32) 2 sigma = torch.tensor([[1, 10], [5, 6]], dtype=torch.float32) 3 distribution1 = torch.distributions.Normal(mu, sigma) 4 distribution1.sample() 5 >>> tensor([[ 2.8919, 29.8885], 6 [-4.1843, 13.6703]]) 7 x = torch.tensor([1, 10, 10, 1], dtype=torch.float32).reshape(-1, 2) 8 distribution1.log_prob(x) 9 >>> tensor([[-0.9189, -3.2215], 10 [-4.1484, -3.8357]])
广播机制,将生成一个多维的高斯分布distribution1,如下
[[N(1, 1), N(10, 10)],
[N(1, 5), N(10, 6)]]
采样的时候对每个位置的一维高斯分布采样就得到一个2x2的矩阵
x也是一个2x2的矩阵,计算概率密度时候就正好直接对x的每个元素与distribution1中每个对应位置的高斯分布计算概率密度函数。
理解(仅供参考)
在多元混合高斯分布中,方差矩阵是各个变量的协方差矩阵,表达了他们之间的相互关系。
程序里的方差(矩阵)就是每一个一维高斯分布的方差,之间是没有关系的,或者是说,这是正交解耦之后的方差。
所以如果由一个神经网络输出程序里多维高斯分布的参数,那么不同分布之间的关系其实是蕴含在神经网络中的,也就是说第一个位置是n(0, 1)时第二个位置为什么是N(5, 6),这个原因就由神经网络表达。(不知道这样理解的对不对)