AAE對抗自編碼器


譯自:https://hjweide.github.io/adversarial-autoencoders

1.自編碼器AE作為生成模型

我們已經簡要提到過,編碼器輸出的屬性使我們能夠將輸入數據轉換為有用的表示形式。在使用變分自動編碼器的情況下,解碼器已受過訓練,可以從類似於我們選擇的先驗樣本的樣本中重建輸入。因此,我們可以從此先驗分布中采樣數據點,並將其饋送到解碼器中,以在原始數據空間中重建逼真的外觀數據點。

不幸的是,變分自動編碼器通常會在先驗分布的空間中留下一些區域,這些區域不會映射到數據中的實際樣本。對抗性自動編碼器旨在通過鼓勵編碼器的輸出完全填充先驗分布的空間來改善此情況,從而使解碼器能夠從先驗采樣的任何數據點生成逼真的樣本。對抗性自動編碼器通過使用兩個新組件,即鑒別器和生成器,來代替使用變分推理。接下來討論這些。

2.訓練更新過程

https://zhuanlan.zhihu.com/p/68903857

對抗自編碼器的網絡結構主要分成兩大部分:自編碼部分(上半部分)、GAN判別網絡(下半部分)。整個框架也就是GAN和AutoEncoder框架二者的結合。訓練過程分成兩個階段:首先是樣本重構階段,通過梯度下降更新自編碼器encoder部分、以及decoder的參數、使得重構損失函數最小化;然后是正則化約束階段,交替更新判別網絡參數和生成網絡(encoder部分)參數以此提高encoder部分混淆判別網絡的能力。

下面這張圖片似乎更加清晰:

圖片來自:https://towardsdatascience.com/a-wizards-guide-to-adversarial-autoencoders-part-2-exploring-latent-space-with-adversarial-2d53a6f8a4f9

上面鏈接中比較清楚地講解了AAE兩階段的訓練過程: 

首先是重建的部分,這部分和正常的AE沒有什么差別。下面是正則化部分:

 

首先,我們訓練判別器對編碼器輸出(z)和一些隨機輸入(z’,目標分布)進行分類。 例如,隨機輸入可以正態分布,平均值為0,標准差為5。

因此,如果我們傳入具有所需分布(真實值)的隨機輸入,則鑒別器應給我們輸出1;而當我們傳入編碼器輸出時,鑒別器應給我們輸出0(偽值)。 直觀地,編碼器的輸出和判別器的隨機輸入都應具有相同的size。
下一步將是強制編碼器輸出具有所需分布的隱碼。 為此,我們將編碼器輸出作為輸入連接到鑒別器:

 

 在AAE中編碼部分相當於GAN的生成器。

我們將判別器的權重固定為當前的權重(使它們無法訓練),並在判別器的輸出端將目標固定為1。 稍后,我們將圖像傳遞到編碼器,並確定判別器輸出,然后將其用於查找損失(交叉熵代價函數)。我們將僅通過編碼器權重進行反向傳播,這會導致編碼器學習所需的分布並產生具有該分布的輸出(將判別器目標固定為1會使得編碼器通過查看判別器權重來學習所需的分布 )。

4.GAN與VAE的區別

轉自:https://www.zhihu.com/question/317623081

一個本質區別就是loss的區別

VAE是pointwise loss 點匹配,一個典型的特征就是pointwise loss常常會脫離數據流形面,因此看起來生成的圖片會模糊;

GAN是分布匹配的loss,更能貼近流行面,看起來就會清晰;

但分布匹配的難度較大,一個例子就是經常發生mode collapse問題,小分布丟失,而pointwise loss就沒有這個問題,可以用於做初始化或做糾正,因此發展了一系列GAN+VAE的工作。

 

VAE希望通過一種 顯式(explicit)的方法找到一個概率密度,並通過最小化對數似函數的下限來得到最優解;GAN則是對抗的方式來尋找一種平衡, 不需要認為給定一個顯式的概率密度函數。

5.AAE的訓練例子

https://blog.paperspace.com/adversarial-autoencoders-with-pytorch/

首先定義模型和各部分優化器:

torch.manual_seed(10)
Q, P = Q_net() = Q_net(), P_net(0)     # Encoder&Decoder
D_gauss = D_net_gauss()                # Discriminator adversarial
# Set optimizators
#AE部分:
P_decoder = optim.Adam(P.parameters(), lr=gen_lr)
Q_encoder = optim.Adam(Q.parameters(), lr=gen_lr)

#GAN部分:
Q_generator = optim.Adam(Q.parameters(), lr=reg_lr)#為Encoder又定義了一個優化器,作為生成器的更新
D_gauss_solver = optim.Adam(D_gauss.parameters(), lr=reg_lr)

包括三個部分的優化:

AE部分編碼和解碼器的優化:

    z_sample = Q(X)
    X_sample = P(z_sample)
    recon_loss = F.binary_cross_entropy(X_sample + TINY, 
                                        X.resize(train_batch_size, X_dim) + TINY)
    recon_loss.backward()
    P_decoder.step()//Enocder
    Q_encoder.step()//Decoder

 判別器D的優化:

    # Compute discriminator outputs and loss
    D_real_gauss, D_fake_gauss = D_gauss(z_real_gauss), D_gauss(z_fake_gauss)
    D_loss_gauss = -torch.mean(torch.log(D_real_gauss + TINY) + torch.log(1 - D_fake_gauss + TINY))
    D_loss.backward()       # Backpropagate loss
    D_gauss_solver.step()   # D判別器的優化

生成器(即Encoder的優化):

# Generator
Q.train()   # Back to use dropout
z_fake_gauss = Q(X)
D_fake_gauss = D_gauss(z_fake_gauss)

G_loss = -torch.mean(torch.log(D_fake_gauss + TINY))
G_loss.backward()
Q_generator.step()#優化Ecnoder,即生成器

 6.AAE示例代碼

https://github.com/bfarzin/pytorch_aae/blob/master/main_aae.py#L114

https://blog.paperspace.com/adversarial-autoencoders-with-pytorch/

https://github.com/shidilrzf/Adversarial-Autoencoders/blob/master/train.py#L147


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM