pytorch01 torch.distributions.Normal和.log_prob()


Normal(means, sigma)的目的在于设置一个高斯分布

其中means的形状和sigma的形状可以不一致,遵循广播原理

from torch.distributions import Normal

mu = torch.tensor([1, 10], dtype=torch.float32)
sigma = torch.tensor([1], dtype=torch.float32)
 
dist = Normal(mu, sigma)  # 设置高斯分布的均值和方差
 
dist.sample()  # 采样

>>> tensor([-0.9153,  8.3727])

设置的高斯分布中sigma虽然只传入了1,这里应该是广播机制,会生成一个二维高斯分布,[N(1,1), N(10, 1)]

对其进行采样dist.sample(),会得到一个数组

>>> tensor([-0.9153,  8.3727])

log_prob(x)用来计算输入数据x在分布中的对于概率密度的对数

x = torch.tensor([1, 10, 10, 1], dtype=torch.float32).reshape(-1, 2)

dist.log_prob(x)

>>> tensor([[ -0.9189,  -0.9189],
        [-41.4189, -41.4189]])

其中,x = [[1, 10], [10, 1]],这一个x包括两个数组[1, 10]和[10, 1]

在log_prob(x)的计算中,分别算的是这两个数组在分布[N(1,1), N(10, 1)]上的概率密度,即:

[1, 10]的1在分布N(1, 1)  上的概率密度对数  -0.9189

[1, 10]的10在分布N(10, 1)上的概率密度对数 -0.9189

为输出结果

>>> tensor([[ -0.9189,  -0.9189],
        [-41.4189, -41.4189]])

中的第一行数据

[10, 1]的10在分布N(1, 1)  上的概率密度对数  -41.4189

[10, 1]的1在分布N(10, 1)  上的概率密度对数  -41.4189

 为输出结果的第二行。

上述说明,这个分布的计算其实就是简单的多个1维不相关的高斯分布分别计算对应位置的数据的 对数 概率密度。

 

每一个数据的计算公式为:

 

 

 如果直接想知道概率密度是多少可以在后面加上  .exp()  还原

dist.log_prob(x).exp()

>>> tensor([[3.9894e-01, 3.9894e-01],
        [1.0280e-18, 1.0280e-18]])

 

 

 

 如图,这是N(1, 1)的概率密度函数在x=1时的取值,从图上大概看出在0.4左右,与dist.log_prob(x).exp()的结果一致

 

再举一个例子:

 1 mu = torch.tensor([1, 10], dtype=torch.float32)
 2 sigma = torch.tensor([[1, 10], [5, 6]], dtype=torch.float32)
 3 distribution1 = torch.distributions.Normal(mu, sigma)
 4 distribution1.sample()
 5 >>> tensor([[ 2.8919, 29.8885],
 6                     [-4.1843, 13.6703]])
 7 x = torch.tensor([1, 10, 10, 1], dtype=torch.float32).reshape(-1, 2)
 8 distribution1.log_prob(x)
 9 >>> tensor([[-0.9189, -3.2215],
10                     [-4.1484, -3.8357]])

广播机制,将生成一个多维的高斯分布distribution1,如下

[[N(1, 1), N(10, 10)],

 [N(1, 5), N(10, 6)]]

采样的时候对每个位置的一维高斯分布采样就得到一个2x2的矩阵

x也是一个2x2的矩阵,计算概率密度时候就正好直接对x的每个元素与distribution1中每个对应位置的高斯分布计算概率密度函数。

 

理解(仅供参考)

在多元混合高斯分布中,方差矩阵是各个变量的协方差矩阵,表达了他们之间的相互关系。

程序里的方差(矩阵)就是每一个一维高斯分布的方差,之间是没有关系的,或者是说,这是正交解耦之后的方差。

所以如果由一个神经网络输出程序里多维高斯分布的参数,那么不同分布之间的关系其实是蕴含在神经网络中的,也就是说第一个位置是n(0, 1)时第二个位置为什么是N(5, 6),这个原因就由神经网络表达。(不知道这样理解的对不对)


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM