變分自編碼器VAE的由來和簡單實現(PyTorch)


變分自編碼器VAE的由來和簡單實現(PyTorch)

​ 之前經常遇到變分自編碼器的概念(\(VAE\)),但是自己對於這個概念總是模模糊糊,今天就系統的對\(VAE\)進行一些整理和回顧。

VAE的由來

​ 假設有一個目標數據\(X=\{X_1,X_2,\cdots,X_n\}\),我們想生成一些數據,即生成\(\hat{X}=\{\hat{X_1},\hat{X_2},\cdots,\hat{X_n}\}\),其分布與\(X\)相同。

​ 但是實際上,這樣存在一些問題,第一是我們如何將生成的\(\hat{X}\)\(X\)一一對應,這就需要我們采用更為精巧的度量方式,即如何度量兩個分布之間的距離;第二是我們如何生成新的\(\hat{X}\),按照朴素的想法,我們可以構造一個函數\(G\),使得\(\hat{X}=G(Z)\) ,如果能構造出這個\(G\),我們就可以通過一個任意的\(Z\),來生成\(\hat{X}\) ,而這里的\(Z\),可以取一個已知的分布,比如正態分布。

目前的問題

​ 目前的問題轉化為了如何構造\(G\),以及如何檢驗我們生成的\(\hat{X}\)是否和\(X\)具有同分布。在\(GAN\)中,這里的\(G\)和分布的相似度衡量都用神經網絡搞定了,一個叫做\(generator\),一個叫做\(discriminator\),這二者互相拮抗,最終使得分布越來接近。

​ 而在我們目前的問題中,\(VAE\)提供了另外一種思路,沿着AutoEncoder的想法,AutoEncoder是通過\(encoder\)把image \(a\)編碼為vector,叫做\(latent{\ }represention\) ,再通過\(decoder\)\(latent{\ }space\)轉為\(\hat{a}\) ,\(\hat{a}\)\(a\)的重建圖像。

​ 但是AE針對每張圖片生成的\(latent{\ }code\)並沒有可解釋性,即sample兩個\(latent{\ } code\)之間的點輸入\(decoder\),得到的結果並不一定具有跟這兩個\(latent code\)相關的特征。為了解決這個問題,提出了VAE:不再采用vector來建模一個\(latent{\ }code\),而是利用一個帶有noise的高斯分布來表示。直觀的理解,在加入noise之后,就有機會將訓練時候train的\(latent{\ }code\)在其latent space下賦予一定的變化能力,使latent space變得更加連續,從而可以在其中采樣從而生成新的圖片。

​ 我們之前生成的\(Z=\{Z_1,Z_2,\cdots,Z_n\}\),現在不再單單生成一個\(Z\),而是生成兩個vector,分別記為\(M=\{{\mu_1},{\mu_2},\cdots\,{\mu_n}\}\),\(\Sigma=\{ {\sigma_1},{\sigma_2},\cdots,\sigma_n\}\),分別代表新生成latent code的高斯分布的均值和方差。在sample的時候就只需要根據從標准正態分布\(\mathcal{N}(0,1)\)中采樣一個\(e_i\),\(e_i\)來自於\(E=\{e_1,e_2,\cdots,e_n\}\),然后利用\(c_i=e_i*exp({\sigma_i})+\mu_i\)(\(reparameterization{\ }trick\)),就得到了我們所需的\(c_i\)\(c_i\)即組成我們需要的\(Z\)=\(\{c_1,c_2,\cdots,c_n\}\)

​ 這里一方面希望\(VAE\)能夠生成盡可能豐富的數據,因此訓練的時候希望在高斯分布中含有噪聲。另一方面優化的過程中會趨向於使圖像質量更好,因此當噪聲為0的時候退化為普通的\(AutoEncoder\),這種情況我們是不希望出現的。為了平衡這種trade-off,這里希望每個\(p(Z|X)\)能夠接近標准正態分布,但是另一方面網絡又趨於使輸入和輸出圖像更為接近,因此會使正態分布的方差向0的方向優化。經過這種對抗過程,最終就能產生具有一定可解釋性的\(decoder\),同時最終得到的\(Z\)的分布也會趨向於\(\mathcal{N}(0,1)\),可以表示為:

​ $$p(Z)=\sum_{X} p(Z \mid X) p(X)=\sum_{X} \mathcal{N}(0, 1) p(X)=\mathcal{N}(0, I) \sum_{X} p(X)=\mathcal{N}(0, 1)$$

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
    def loss_function_original(recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

​ 這里的loss由兩部分組成,一部分是重建loss,一部分是使各個高斯分布趨近於標准高斯分布的loss(由KL散度推導得到)。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM