Auto-Encoding Variational Bayes 公式推導及代碼


變分自動編碼器(VAE)用於生成模型,結合了深度模型以及靜態推理。簡單來說就是通過映射學習將一個高維數據,例如一幅圖片映射到低維空間Z。與標准自動編碼器不同的是,X和Z是隨機變量。所以可以這么理解,嘗試從P(X|Z)中去采樣出x,所以利用這個可以生成人臉,數字以及語句的生成。

1.模型


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
graph LR A[Data] -->|DNN_1| B(mu,std) B --> C[z] D(eps) --> C C-->|DNN_2|E[gen_data]

以上為模型的代碼和圖示,我們構建的為DNN_1和DNN_2模型,Data-->z為編碼器,z-->gen_data為解碼器。

2.損失函數

2.1 L

假設x-->z的真實概率分布為p(z|x),我們的模型的概率分布為q(z|x),那么損失可用他們的KL散度表示【\(\sum_z q(z|x) \log \frac{q(z|x)}{p(z|x)}\)】恆大於0,其值越大表示相似度越低,即損失越大

\[ \sum_z q(z|x) \log \frac{q(z|x)}{p(z|x)} =\sum_z q(z|x) \log (p(x)\frac{q(z|x)}{p(z|x)*p(x)}) =\sum_z q(z|x) \log (p(x)\frac{q(z|x)}{p(z,x)}) =\log p(x)+\sum_z q(z|x) \log (\frac{q(z|x)}{p(z,x)}) \]

由於p(x)為真實分布,是個固定分布,故最小化KL(q(z|x)||p(z|x))即最大化\(\sum_z q(z|x) \log (\frac{p(z,x)}{q(z|x)})\),我們設其為L

\[L=\sum_z q(z|x) \log (\frac{p(z)*p(x|z)}{q(z|x)})=\sum_z q(z|x) \log (\frac{p(z)}{q(z|x)})+\sum_z q(z|x) \log (p(x|z))=L_1+L_2 \]

我們假設z的先驗概率p(z)是N(0,1)分布,而我們的模型學到的q(z|x)是N(mu,std)分布

2.1.1 L1

\[L_1=E_{z\sim N(\mu,\sigma^2)} \log \frac{\frac{1}{\sqrt{2\pi}}e^{\frac{-z^2}{2}}}{\frac{1}{\sqrt{2\pi}\sigma}e^{\frac{(z-\mu)^2}{2\sigma^2}}}=E_{z\sim N(\mu,\sigma^2)} (\log \sigma-\frac{z^2}{2}+\frac{(z-\mu)^2}{2\sigma^2})= E_{z\sim N(\mu,\sigma^2)} (\log \sigma-\frac{z^2}{2}+\frac{z^2+\mu^2-2\mu z}{2\sigma^2}) \]

對於正態分布\(z\sim N(\mu,\sigma)\),\(E(z)=\mu,E(z^2)=D(z)+E(z)^2=\mu^2+\sigma^2\)
因此

\[L_1=\log \sigma -\frac{\mu^2+\sigma^2}{2}+\frac{\mu^2+\sigma^2+\mu^2-2\mu\mu}{2\sigma^2}=\log \sigma +\frac{1}{2}(1-\mu^2-\sigma^2) \]

2.1.2 L2

如果直接從 \(N(\mu,\sigma)\) 中采樣,那么采樣的結果是不可導的,我們通過重參數技巧來解決不能梯度下降的問題,即采樣$ \epsilon \sim N(0,1) $,用 $ \epsilon * \sigma +\mu $ 來代替 $ z \sim N(\mu,\sigma)$。用蒙特卡洛方法來計算L2。
這個是對應解碼器部分【交叉熵】

2.2 代碼

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM