Pytorch入門之VAE


關於自編碼器的原理見另一篇博客 : 編碼器AE & VAE

這里談談對於變分自編碼器(Variational auto-encoder)即VAE的實現。

 

1. 稀疏編碼

首先介紹一下“稀疏編碼”這一概念。

       早期學者在黑白風景照片中可以提取到許多16*16像素的圖像碎片。而這些圖像碎片幾乎都可由64種正交的邊組合得到。而且組合出一張碎片所需的邊的數目很少,即稀疏的。同時在音頻中大多數聲音也可由幾種基本結構組合得到。這其實就是特征的稀疏表達。即使用少量的基本特征來組合更加高層抽象的特征。在神經網絡中即體現出前一層是未加工的像素,而后一層就是對這些像素的非線性組合。

       有監督情況下可以利用深層卷積網絡來提取特征,而自編碼器就是無監督情況下根據自身的高階特征編碼自己。自編碼器是輸入輸出相同的神經網絡。其特點是利用稀疏的高階特征來重構自己。一般而言自編碼器的中間隱層節點的數量要小於輸入節點的數量,即實現降維過程。因為對於少於輸入節點的隱藏層來說無法將輸入的全部信息保留,只能優先選擇部分重要的特征,而后利用這些特征來復原。此外我們可以給隱層的權重加上L2正則,正則項懲罰因子越大,接近於0的系數越多,從而特征更加稀疏!   

       關於自編碼器我們可以加入一些限制使其實現不同的功能,例如去噪自編碼(Denoising AutoEncoder)。輸入是加了噪聲的數據,而輸出是原始數據,在學習過程中,只有學到更魯棒、更頻繁的特征模式才能將噪聲略去,回復原始數據。如果自編碼器的隱層只有一層,那么原理類似於主成分分析PCA。

        HInton提出的DBN模型有多個隱含層,每個隱含層都是限制玻爾茲曼機RBM。DBN訓練時需先對每兩層間進行無監督的預訓練,這一過程實為一個多層的自編碼器,可以將每整個網絡的權重初始化到一個理想的分布。最后通過反向傳播算法調整模型權重,這個步驟會使用經過標注的信息來做監督性的分類訓練。當年DBN給訓練深度神經網絡提供了可能性,它解決了網絡過深帶來的深度彌散。簡言之:先用自編碼器的方法進行無監督的預訓練,提取特征並初始化權重,然后使用標注信息進行監督式的訓練。

 

2.  VAE工作流程

先看下圖:

        AE的工作其實是實現了    圖片->向量->圖片   這一過程。就是說給定一張圖片編碼后得到一個向量,然后將這一向量進行解碼后就得到了原始的圖片。這個解碼后的圖片和之前的原圖一樣嗎?不完全一樣。因為一般而言,如前所述是從低維隱層中恢復原圖。但是AE另我們現在能訓練任意多的圖片,如果我們把這些圖片的編碼向量存在來,那以后就能通過這些編碼向量來重構我們的圖像,稱之為標准自編碼器。可這還不夠,如果現在我隨機拿出一個很離譜的向量直接另其解碼,那解碼出來的東西十有八九是無意義的東西。

       所以我們希望AE編碼出的code符合一種分布(eg:高斯混合模型),那么我們就可以從這個高斯分布任意采樣出一個code,給這個code解碼那么就會生成一張原圖類似的圖。而這個強迫分布就是VAE與AE的不同之處了。VAE的編碼器輸出包括兩部分:m和σ。其中e是正態分布, c為編碼結果。m、e、σ、c的形狀一樣,都為(batch_size,latent_code_num) 。這個latent_code_num就相當於高斯混合分布的高斯數量。每個高斯都有自己的均值、方差。所以共有latent_code_num個均值、方差。

       接下來是VAE的損失函數:由兩部分的和組成(bce_loss、kld_loss)。bce_loss即為binary_cross_entropy(二分類交叉熵)損失,即用於衡量原圖與生成圖片的像素誤差。kld_loss即為KL-divergence(KL散度),用來衡量潛在變量的分布和單位高斯分布的差異。

 

 3. Pytorch實現 

  1 #!/usr/bin/env python3
  2 # -*- coding: utf-8 -*-
  3 """
  4 Created on Sat Mar 10 20:48:03 2018
  5 
  6 @author: lps
  7 """
  8 
  9 import torch
 10 import torch.nn as nn
 11 import torch.optim as optim
 12 import torch.nn.functional as F
 13 from torch.autograd import Variable
 14 from torchvision import transforms
 15 import torchvision.datasets as dst
 16 from torchvision.utils import save_image
 17 
 18 
 19 EPOCH = 15
 20 BATCH_SIZE = 64
 21 n = 2   # num_workers
 22 LATENT_CODE_NUM = 32   
 23 log_interval = 10
 24 
 25 
 26 transform=transforms.Compose([transforms.ToTensor()])
 27 data_train = dst.MNIST('MNIST_data/', train=True, transform=transform, download=False)
 28 data_test = dst.MNIST('MNIST_data/', train=False, transform=transform)
 29 train_loader = torch.utils.data.DataLoader(dataset=data_train, num_workers=n,batch_size=BATCH_SIZE, shuffle=True)
 30 test_loader = torch.utils.data.DataLoader(dataset=data_test, num_workers=n,batch_size=BATCH_SIZE, shuffle=True)
 31 
 32 
 33 class VAE(nn.Module):
 34       def __init__(self):
 35             super(VAE, self).__init__()
 36       
 37             self.encoder = nn.Sequential(
 38                   nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
 39                   nn.BatchNorm2d(64),
 40                   nn.LeakyReLU(0.2, inplace=True),
 41                   
 42                   nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
 43                   nn.BatchNorm2d(128),
 44                   nn.LeakyReLU(0.2, inplace=True),
 45                       
 46                   nn.Conv2d(128, 128, kernel_size=3 ,stride=1, padding=1),
 47                   nn.BatchNorm2d(128),
 48                   nn.LeakyReLU(0.2, inplace=True),                  
 49                   )
 50             
 51             self.fc11 = nn.Linear(128 * 7 * 7, LATENT_CODE_NUM)
 52             self.fc12 = nn.Linear(128 * 7 * 7, LATENT_CODE_NUM)
 53             self.fc2 = nn.Linear(LATENT_CODE_NUM, 128 * 7 * 7)
 54             
 55             self.decoder = nn.Sequential(                
 56                   nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
 57                   nn.ReLU(inplace=True),
 58                   
 59                   nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
 60                   nn.Sigmoid()
 61                   )
 62 
 63       def reparameterize(self, mu, logvar):
 64             eps = Variable(torch.randn(mu.size(0), mu.size(1))).cuda()
 65             z = mu + eps * torch.exp(logvar/2)            
 66             
 67             return z
 68       
 69       def forward(self, x):
 70              out1, out2 = self.encoder(x), self.encoder(x)  # batch_s, 8, 7, 7
 71              mu = self.fc11(out1.view(out1.size(0),-1))     # batch_s, latent
 72              logvar = self.fc12(out2.view(out2.size(0),-1)) # batch_s, latent
 73              z = self.reparameterize(mu, logvar)      # batch_s, latent      
 74              out3 = self.fc2(z).view(z.size(0), 128, 7, 7)    # batch_s, 8, 7, 7
 75              
 76              return self.decoder(out3), mu, logvar
 77 
 78 
 79 def loss_func(recon_x, x, mu, logvar):
 80       BCE = F.binary_cross_entropy(recon_x, x,  size_average=False)
 81       KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
 82       
 83       return BCE+KLD
 84 
 85 
 86 vae = VAE().cuda()
 87 optimizer =  optim.Adam(vae.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
 88 
 89 
 90 def train(EPOCH):
 91       vae.train()
 92       total_loss = 0
 93       for i, (data, _) in enumerate(train_loader, 0):
 94             data = Variable(data).cuda()
 95             optimizer.zero_grad()
 96             recon_x, mu, logvar = vae.forward(data)
 97             loss = loss_func(recon_x, data, mu, logvar)
 98             loss.backward()
 99             total_loss += loss.data[0]
100             optimizer.step()
101             
102             if i % log_interval == 0:
103                   sample = Variable(torch.randn(64, LATENT_CODE_NUM)).cuda()
104                   sample = vae.decoder(vae.fc2(sample).view(64, 128, 7, 7)).cpu()
105                   save_image(sample.data.view(64, 1, 28, 28),
106                'result/sample_' + str(epoch) + '.png')
107                   print('Train Epoch:{} -- [{}/{} ({:.0f}%)] -- Loss:{:.6f}'.format(
108                               epoch, i*len(data), len(train_loader.dataset), 
109                               100.*i/len(train_loader), loss.data[0]/len(data)))
110                   
111       print('====> Epoch: {} Average loss: {:.4f}'.format(
112           epoch, total_loss / len(train_loader.dataset)))
113       
114 for epoch in range(1, EPOCH):
115     train(epoch)
116       
main.py

編解碼器可由全連接或卷積網絡實現。這里采用CNN。結果如下:

 

       

       

         

 

 

參考 :

《Tensoflow 實戰》

Pytorch tutorial

Paper-Implementations

yunjey/pytorch-tutorial


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM