pytorch01 torch.distributions.Normal和.log_prob()


Normal(means, sigma)的目的在於設置一個高斯分布

其中means的形狀和sigma的形狀可以不一致,遵循廣播原理

from torch.distributions import Normal

mu = torch.tensor([1, 10], dtype=torch.float32)
sigma = torch.tensor([1], dtype=torch.float32)
 
dist = Normal(mu, sigma)  # 設置高斯分布的均值和方差
 
dist.sample()  # 采樣

>>> tensor([-0.9153,  8.3727])

設置的高斯分布中sigma雖然只傳入了1,這里應該是廣播機制,會生成一個二維高斯分布,[N(1,1), N(10, 1)]

對其進行采樣dist.sample(),會得到一個數組

>>> tensor([-0.9153,  8.3727])

log_prob(x)用來計算輸入數據x在分布中的對於概率密度的對數

x = torch.tensor([1, 10, 10, 1], dtype=torch.float32).reshape(-1, 2)

dist.log_prob(x)

>>> tensor([[ -0.9189,  -0.9189],
        [-41.4189, -41.4189]])

其中,x = [[1, 10], [10, 1]],這一個x包括兩個數組[1, 10]和[10, 1]

在log_prob(x)的計算中,分別算的是這兩個數組在分布[N(1,1), N(10, 1)]上的概率密度,即:

[1, 10]的1在分布N(1, 1)  上的概率密度對數  -0.9189

[1, 10]的10在分布N(10, 1)上的概率密度對數 -0.9189

為輸出結果

>>> tensor([[ -0.9189,  -0.9189],
        [-41.4189, -41.4189]])

中的第一行數據

[10, 1]的10在分布N(1, 1)  上的概率密度對數  -41.4189

[10, 1]的1在分布N(10, 1)  上的概率密度對數  -41.4189

 為輸出結果的第二行。

上述說明,這個分布的計算其實就是簡單的多個1維不相關的高斯分布分別計算對應位置的數據的 對數 概率密度。

 

每一個數據的計算公式為:

 

 

 如果直接想知道概率密度是多少可以在后面加上  .exp()  還原

dist.log_prob(x).exp()

>>> tensor([[3.9894e-01, 3.9894e-01],
        [1.0280e-18, 1.0280e-18]])

 

 

 

 如圖,這是N(1, 1)的概率密度函數在x=1時的取值,從圖上大概看出在0.4左右,與dist.log_prob(x).exp()的結果一致

 

再舉一個例子:

 1 mu = torch.tensor([1, 10], dtype=torch.float32)
 2 sigma = torch.tensor([[1, 10], [5, 6]], dtype=torch.float32)
 3 distribution1 = torch.distributions.Normal(mu, sigma)
 4 distribution1.sample()
 5 >>> tensor([[ 2.8919, 29.8885],
 6                     [-4.1843, 13.6703]])
 7 x = torch.tensor([1, 10, 10, 1], dtype=torch.float32).reshape(-1, 2)
 8 distribution1.log_prob(x)
 9 >>> tensor([[-0.9189, -3.2215],
10                     [-4.1484, -3.8357]])

廣播機制,將生成一個多維的高斯分布distribution1,如下

[[N(1, 1), N(10, 10)],

 [N(1, 5), N(10, 6)]]

采樣的時候對每個位置的一維高斯分布采樣就得到一個2x2的矩陣

x也是一個2x2的矩陣,計算概率密度時候就正好直接對x的每個元素與distribution1中每個對應位置的高斯分布計算概率密度函數。

 

理解(僅供參考)

在多元混合高斯分布中,方差矩陣是各個變量的協方差矩陣,表達了他們之間的相互關系。

程序里的方差(矩陣)就是每一個一維高斯分布的方差,之間是沒有關系的,或者是說,這是正交解耦之后的方差。

所以如果由一個神經網絡輸出程序里多維高斯分布的參數,那么不同分布之間的關系其實是蘊含在神經網絡中的,也就是說第一個位置是n(0, 1)時第二個位置為什么是N(5, 6),這個原因就由神經網絡表達。(不知道這樣理解的對不對)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM