Kingma D P, Welling M. Auto-Encoding Variational Bayes[J]. arXiv: Machine Learning, 2013.
主要內容
自編碼, 通過引入Encoder和Decoder來估計聯合分布\(p(x,z)\), 其中\(z\)表示隱變量(我們也可以讓\(z\)為樣本標簽, 使得Encoder成為一個判別器).
在Decoder中我們建立聯合分布\(p_{\theta}(x,z)\)以估計\(p(x,z)\), 在Encoder中建立一個后驗分布\(q_{\phi}(z|x)\)去估計\(p_{\theta}(z|x)\), 然后極大似然:
上式倆邊關於\(z\)在分布\(q_{\phi}(z)\)下求期望可得:
既然KL散度非負, 我們極大似然\(\log p_{\theta}(x)\)可以退而求其次, 最大化\(\mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} )\)(ELBO, 記為\(\mathcal{L}\)).
又(\(p_{\theta}(z)\)為認為給定的先驗分布)
我們接下來通過對Encoder和Decoder的一些構造進一步擴展上面倆項.
Encoder (損失part1)
Encoder 將\(x\rightarrow z\), 就相當於在\(q_{\phi}(z|x)\)中進行采樣, 但是如果是直接采樣的話, 就沒法利用梯度回傳進行訓練了, 這里需要一個重參化技巧.
我們假設\(q_{\phi}(z|x)\)為高斯密度函數, 即\(\mathcal{N}(\mu, \sigma^2 I)\).
注: 文中還提到了其他的一些可行假設.
我們構建一個神經網絡\(f\), 其輸入為樣本\(x\), 輸出為\((\mu, \log \sigma)\)(輸出\(\log \sigma\)是為了保證\(\sigma\)為正), 則
其中\(\odot\)表示按元素相乘.
注: 我們可以該輸出為\((\mu, L)\)(\(L\)為三角矩陣, 且對角線元素非負), 而假設\(q_{\phi}(z|x)\)的分量不獨立, 其協方差函數為\(L^TL\), 則\((z=\mu + L \epsilon\)).
當\(p_{\theta}(z)=\mathcal{N}(0, I)\), 我們可以顯示表達出:
Decoder (損失part2)
現在我們需要處理的是第二項, 文中這地方因為直接設計\(p_{\theta}(x,z)\)不容易, 在我看來存粹是做不到的, 但是又用普通的分布代替不符合常理, 所以首先設計一個網絡\(g_{\theta}(z)\), 其輸出為\(\hat{x}\), 然后假設\(p(x|\hat{x})\)的分布, 第二項就改為近似\(\mathbb{E}_{q_{\phi}(z|x)}p_{\theta}(x|\hat{x})\).
這么做的好處是顯而易見的, 因為Decoder部分, 我們可以通過給定一個\(z\)然后獲得一個\(\hat{x}\), 這是很有用的東西, 但是我認為這種不是很合理, 因為除非\(g\)是可逆的, 那么\(p_{\theta}(x|z)= p _{\theta}(x|\hat{x})\) (當然, 別無選擇).
伯努利分布
此時\(\hat{x}=g(z)\)是\(x=1\)的概率, 則此時第二項的損失為
為(二分類)交叉熵損失.
高斯分布
一種簡單粗暴的, \(p(x|\hat{x})=\mathcal{N}(\hat{x},\sigma^2 I)\), 此時損失為類平方損失, 文中也有別的變換.
代碼
import torch
import torch.nn as nn
class Loss(nn.Module):
def __init__(self, part2):
super(Loss, self).__init__()
self.part2 = part2
def forward(self, mu, sigma, real, fake, lam=1):
part1 = (1 + torch.log(sigma ** 2)
- mu ** 2 - sigma ** 2).sum() / 2
part2 = self.part2(fake, real)
return part1 + lam * part2