Auto-Encoding Variational Bayes


Kingma D P, Welling M. Auto-Encoding Variational Bayes[J]. arXiv: Machine Learning, 2013.

主要內容

自編碼, 通過引入Encoder和Decoder來估計聯合分布\(p(x,z)\), 其中\(z\)表示隱變量(我們也可以讓\(z\)為樣本標簽, 使得Encoder成為一個判別器).

在Decoder中我們建立聯合分布\(p_{\theta}(x,z)\)以估計\(p(x,z)\), 在Encoder中建立一個后驗分布\(q_{\phi}(z|x)\)去估計\(p_{\theta}(z|x)\), 然后極大似然:

\[\begin{array}{ll} \log p_{\theta}(x) &= \log \frac{p_{\theta}(x,z)}{p_{\theta}(z|x)} \\ & = \log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)} \\ & = \log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} + \log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)} \\ \end{array}, \]

上式倆邊關於\(z\)在分布\(q_{\phi}(z)\)下求期望可得:

\[\begin{array}{ll} \log p_{\theta}(x) & = \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} + \log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)}) \\ &= \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} )+D_{KL}(q_{\phi}(z|x)\| p_{\theta}(z |x ))\\ & \ge \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} ) \end{array}. \]

既然KL散度非負, 我們極大似然\(\log p_{\theta}(x)\)可以退而求其次, 最大化\(\mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} )\)(ELBO, 記為\(\mathcal{L}\)).

又(\(p_{\theta}(z)\)為認為給定的先驗分布)

\[\begin{array}{ll} \mathcal{L}(\theta, \phi; x) &= -D_{KL}(q_{\phi}(z|x)\|p_{\theta}(z))+\mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)], \end{array} \]

我們接下來通過對Encoder和Decoder的一些構造進一步擴展上面倆項.

Encoder (損失part1)

Encoder 將\(x\rightarrow z\), 就相當於在\(q_{\phi}(z|x)\)中進行采樣, 但是如果是直接采樣的話, 就沒法利用梯度回傳進行訓練了, 這里需要一個重參化技巧.

我們假設\(q_{\phi}(z|x)\)為高斯密度函數, 即\(\mathcal{N}(\mu, \sigma^2 I)\).
注: 文中還提到了其他的一些可行假設.

我們構建一個神經網絡\(f\), 其輸入為樣本\(x\), 輸出為\((\mu, \log \sigma)\)(輸出\(\log \sigma\)是為了保證\(\sigma\)為正), 則

\[z= \mu + \epsilon \odot \sigma, \epsilon \sim \mathcal{N}(0, I), \]

其中\(\odot\)表示按元素相乘.
注: 我們可以該輸出為\((\mu, L)\)(\(L\)為三角矩陣, 且對角線元素非負), 而假設\(q_{\phi}(z|x)\)的分量不獨立, 其協方差函數為\(L^TL\), 則\((z=\mu + L \epsilon\)).

\(p_{\theta}(z)=\mathcal{N}(0, I)\), 我們可以顯示表達出:
在這里插入圖片描述
在這里插入圖片描述
在這里插入圖片描述

Decoder (損失part2)

現在我們需要處理的是第二項, 文中這地方因為直接設計\(p_{\theta}(x,z)\)不容易, 在我看來存粹是做不到的, 但是又用普通的分布代替不符合常理, 所以首先設計一個網絡\(g_{\theta}(z)\), 其輸出為\(\hat{x}\), 然后假設\(p(x|\hat{x})\)的分布, 第二項就改為近似\(\mathbb{E}_{q_{\phi}(z|x)}p_{\theta}(x|\hat{x})\).

這么做的好處是顯而易見的, 因為Decoder部分, 我們可以通過給定一個\(z\)然后獲得一個\(\hat{x}\), 這是很有用的東西, 但是我認為這種不是很合理, 因為除非\(g\)是可逆的, 那么\(p_{\theta}(x|z)= p _{\theta}(x|\hat{x})\) (當然, 別無選擇).

伯努利分布

此時\(\hat{x}=g(z)\)\(x=1\)的概率, 則此時第二項的損失為

\[\log p(\mathbf{x}| \hat{\mathbf{x}})= \sum_{i=1} x_i \log \hat{x}_i + (1-x_i) \log (1- \hat{x}_i), \]

為(二分類)交叉熵損失.

高斯分布

一種簡單粗暴的, \(p(x|\hat{x})=\mathcal{N}(\hat{x},\sigma^2 I)\), 此時損失為類平方損失, 文中也有別的變換.

代碼

import torch
import torch.nn as nn


class Loss(nn.Module):
    def __init__(self, part2):
        super(Loss, self).__init__()
        self.part2 = part2

    def forward(self, mu, sigma, real, fake, lam=1):
        part1 = (1 + torch.log(sigma ** 2)
                 - mu ** 2 - sigma ** 2).sum() / 2
        part2 = self.part2(fake, real)
        return part1 + lam * part2


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM