VAE論文學習


intractable棘手的,難處理的  posterior distributions后驗分布 directed probabilistic有向概率

approximate inference近似推理  multivariate Gaussian多元高斯  diagonal對角 maximum likelihood極大似然

參考:https://blog.csdn.net/yao52119471/article/details/84893634

 

VAE論文所在講的問題是:

我們現在就是想要訓練一個模型P(x),並求出其參數Θ:

 

通過極大似然估計求其參數

 

Variational Inference

在論文中P(x)模型會被拆分成兩部分,一部分由數據x生成潛在向量z,即pθ(z|X);一部分從z重新在重構數據x,即pθ(X|z)

實現過程則是希望能夠使用一個qΦ(z|X)模型去近似pθ(z|X),然后作為模型的Encoder;后半部分pθ(X|z)則作為Decoder,Φ/θ表示參數,實現一種同時學習識別模型參數φ和參數θ的生成模型的方法,推導過程為:

 

 

 

 

現在問題就在於怎么進行求導,因為現在模型已經不是一個完整的P(x) = pθ(z|X) + pθ(X|z),現在變成了P(x) = qΦ(z|X) + pθ(X|z),那么如果對Φ求導就會變成一個問題,因此論文中就提出了一個reparameterization trick方法:

 

 取樣於一個標准正態分布來采樣z,以此將qΦ(z|X) 和pθ(X|z)兩個子模型通過z連接在了一起

 

最終的目標函數為:

 

因此目標函數 = 輸入和輸出x求MSELoss - KL(qΦ(z|X) || pθ(z))

在論文上對式子最后的KL散度 -KL(qΦ(z|X) || pθ(z))的計算有簡化為:

 多維KL散度的推導可見:KL散度

 假設pθ(z)服從標准正態分布,采樣ε服從標准正態分布滿足該假設

 

 

 

 

簡單代碼實現:

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt



class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))


class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))


class VAE(torch.nn.Module):
    latent_dim = 8

    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._enc_mu = torch.nn.Linear(100, 8)
        self._enc_log_sigma = torch.nn.Linear(100, 8)

    def _sample_latent(self, h_enc):
        """
        Return the latent normal sample z ~ N(mu, sigma^2)
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc) #得到的值是loge(sigma)
        sigma = torch.exp(log_sigma) # = e^loge(sigma) = sigma
        #從均勻分布中取樣
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()

        self.z_mean = mu
        self.z_sigma = sigma

        return mu + sigma * Variable(std_z, requires_grad=False)  # Reparameterization trick

    def forward(self, state):
        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)

# 計算KL散度的公式
def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)


if __name__ == '__main__':

    input_dim = 28 * 28
    batch_size = 32

    transform = transforms.Compose(
        [transforms.ToTensor()])
    mnist = torchvision.datasets.MNIST('./', download=True, transform=transform)

    dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size,
                                             shuffle=True, num_workers=2)

    print('Number of samples: ', len(mnist))

    encoder = Encoder(input_dim, 100, 100)
    decoder = Decoder(8, 100, input_dim)
    vae = VAE(encoder, decoder)

    criterion = nn.MSELoss()

    optimizer = optim.Adam(vae.parameters(), lr=0.0001)
    l = None
    for epoch in range(100):
        for i, data in enumerate(dataloader, 0):
            inputs, classes = data
            inputs, classes = Variable(inputs.resize_(batch_size, input_dim)), Variable(classes)
            optimizer.zero_grad()
            dec = vae(inputs)
            ll = latent_loss(vae.z_mean, vae.z_sigma)
            loss = criterion(dec, inputs) + ll
            loss.backward()
            optimizer.step()
            l = loss.data[0]
        print(epoch, l)

    plt.imshow(vae(inputs).data[0].numpy().reshape(28, 28), cmap='gray')
    plt.show(block=True)
View Code

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM