Normal(means, sigma)的目的在於設置一個高斯分布
其中means的形狀和sigma的形狀可以不一致,遵循廣播原理
from torch.distributions import Normal mu = torch.tensor([1, 10], dtype=torch.float32) sigma = torch.tensor([1], dtype=torch.float32) dist = Normal(mu, sigma) # 設置高斯分布的均值和方差 dist.sample() # 采樣 >>> tensor([-0.9153, 8.3727])
設置的高斯分布中sigma雖然只傳入了1,這里應該是廣播機制,會生成一個二維高斯分布,[N(1,1), N(10, 1)]
對其進行采樣dist.sample(),會得到一個數組
>>> tensor([-0.9153, 8.3727])
log_prob(x)用來計算輸入數據x在分布中的對於概率密度的對數
x = torch.tensor([1, 10, 10, 1], dtype=torch.float32).reshape(-1, 2) dist.log_prob(x) >>> tensor([[ -0.9189, -0.9189], [-41.4189, -41.4189]])
其中,x = [[1, 10], [10, 1]],這一個x包括兩個數組[1, 10]和[10, 1]
在log_prob(x)的計算中,分別算的是這兩個數組在分布[N(1,1), N(10, 1)]上的概率密度,即:
[1, 10]的1在分布N(1, 1) 上的概率密度對數 -0.9189
[1, 10]的10在分布N(10, 1)上的概率密度對數 -0.9189
為輸出結果
>>> tensor([[ -0.9189, -0.9189],
[-41.4189, -41.4189]])
中的第一行數據
[10, 1]的10在分布N(1, 1) 上的概率密度對數 -41.4189
[10, 1]的1在分布N(10, 1) 上的概率密度對數 -41.4189
為輸出結果的第二行。
上述說明,這個分布的計算其實就是簡單的多個1維不相關的高斯分布分別計算對應位置的數據的 對數 概率密度。
每一個數據的計算公式為:

如果直接想知道概率密度是多少可以在后面加上 .exp() 還原
dist.log_prob(x).exp() >>> tensor([[3.9894e-01, 3.9894e-01], [1.0280e-18, 1.0280e-18]])

如圖,這是N(1, 1)的概率密度函數在x=1時的取值,從圖上大概看出在0.4左右,與dist.log_prob(x).exp()的結果一致
再舉一個例子:
1 mu = torch.tensor([1, 10], dtype=torch.float32) 2 sigma = torch.tensor([[1, 10], [5, 6]], dtype=torch.float32) 3 distribution1 = torch.distributions.Normal(mu, sigma) 4 distribution1.sample() 5 >>> tensor([[ 2.8919, 29.8885], 6 [-4.1843, 13.6703]]) 7 x = torch.tensor([1, 10, 10, 1], dtype=torch.float32).reshape(-1, 2) 8 distribution1.log_prob(x) 9 >>> tensor([[-0.9189, -3.2215], 10 [-4.1484, -3.8357]])
廣播機制,將生成一個多維的高斯分布distribution1,如下
[[N(1, 1), N(10, 10)],
[N(1, 5), N(10, 6)]]
采樣的時候對每個位置的一維高斯分布采樣就得到一個2x2的矩陣
x也是一個2x2的矩陣,計算概率密度時候就正好直接對x的每個元素與distribution1中每個對應位置的高斯分布計算概率密度函數。
理解(僅供參考)
在多元混合高斯分布中,方差矩陣是各個變量的協方差矩陣,表達了他們之間的相互關系。
程序里的方差(矩陣)就是每一個一維高斯分布的方差,之間是沒有關系的,或者是說,這是正交解耦之后的方差。
所以如果由一個神經網絡輸出程序里多維高斯分布的參數,那么不同分布之間的關系其實是蘊含在神經網絡中的,也就是說第一個位置是n(0, 1)時第二個位置為什么是N(5, 6),這個原因就由神經網絡表達。(不知道這樣理解的對不對)
