GAN2-訓練GAN時所遇到的問題及可能的解決方法


問題1,模式坍塌(Mode collapse )

  • 對模式崩潰產生原因的猜想:

    • GAN的學習目標是映射關系G:x➡y,這種單一域之間的對應關系是高度約束不足的,無法為分類器和判別其的訓練提供足夠的信息輸入。
    • 在這種情況下所優化得到的G可以將域X轉換為與Y分布相同的域Y',但是並不能確保單獨的輸入和輸出樣本x和y是以一種有意義的方式配對的——無限多種映射G(由訓練過程的隨機性產生)針對單獨的輸入x可能產生無限多種y(對G的約束依舊不足,只是保證了分布域上的一致,而且每一個分布域都是由多子域(比如多個類別)所組成的,如果只是將x轉換為Y中的其中一個子域(此時生成的y全部都是同一個類別),這樣依舊可以使得損失函數降到最低,從而使得G和D達到局部最優並使訓練停止,這也是模式崩潰產生的原因)。
    • 模式崩潰問題導致很難孤立地優化對抗性目標,經常發生所有輸入都映射到相同輸出並且優化無法取得進展的現象。(或許將每一個類別都單獨挑出來分別生成獨有的G是一種根治的方案,不過這樣的G也不要想着有什么泛化能力了)。
  • spectral collapse(譜崩潰)& spectral regularization(頻譜正則化)

    • 文獻:Spectral regularization for combating mode collapse in GANs

    • 代碼:https://github.com/max-liu-112/SRGANs-Spectral-Regularization-GANs-

    • 概念及定義:

      • 譜崩潰:當模式崩潰發生時,權重矩陣的奇異值急劇下降的現象稱為譜崩潰,作者發現模式崩潰和頻譜崩潰並存的現象普遍存在,而本文通過解決譜崩潰來解決模式崩潰問題。

      • 譜歸一化權重矩陣(spectral nornalized weight matrix)\(\bar{W}_{SN}(W)\);當模型沒有模式崩潰發生時,\(\bar{W}_{SN}(W)\)中大部分值接近1;而當模式崩潰發生時,\(\bar{W}_{SN}(W)\)中的值會急劇下降。(作者在文中做了一些實驗,說明了這一現象,但沒有從理論層面證明為什么會發生這一現象)

        \[\bar{W}_{SN}(W):=\frac{W}{\sigma(W)} \]

        其中\(\sigma(W)\)是D中權重矩陣的[譜范數](https://mathworld.wolfram.com/SpectralNorm.html#:~:text=Spectral Norm. The natural norm induced by the,root of the maximum eigenvalue of %2C i.e.%2C)(WolframMathWorld-一個神奇的網站),相當於權重矩陣的最大奇異值。

      • 權重矩陣\(W\)的奇異值分解(singular value decomposition)(原文中的公式)

        \[W=U\cdot\sum\cdot{V^T} \]

        \[U^TU=I \]

        \[V^TV=I \]

        \[\sum=[\begin{matrix}D&0\\0&0\end{matrix}] \]

        其中\(D=diag{\{\sigma_1,\sigma_2,\cdots,\sigma_r\}}\)

    • 解決方案:頻譜正則化通過補償頻譜分布避免頻譜崩潰,從而對D的權重矩陣施加約束(核心思想:防止D的權重矩陣W集中到一個方向上)。有兩種頻譜正則化方案,

      • 頻譜正則化(spectral regularization)

        • 靜態補償(static compensation):需要手動確定超參數\(i\),不易於應用。

          \[\Delta{D}=\left[\begin{matrix}\sigma_1-\sigma_1 & 0 & \cdots & \cdots & \cdots & 0 \\ 0 & \sigma_1-\sigma_2 & \cdots & \cdots & \cdots & 0 \\ \vdots & \cdots & \ddots & \cdots & \cdots & 0\\ \vdots & \cdots & \cdots & \sigma_1-\sigma_i & \cdots & 0\\ \vdots & \cdots & \cdots & \cdots & \ddots & 0\\ 0 & \cdots & \cdots & \cdots & \cdots & 0 \end{matrix}\right] \]

        • 動態補償(dynamic compensation):沒有需要手動確定的超參數,相比於靜態補償使用起來更方便。

          \[\Delta{D^T}=\left[\begin{matrix}0 & 0 & \cdots & 0 \\ 0 & \gamma_2^T\cdot{\sigma_1^T-\sigma_2^T} & \cdots & 0 \\ \vdots & \cdots & \ddots & 0\\ 0 & 0 & \cdots & \gamma_r^T\cdot{\sigma_1^T-\sigma_r^T} \end{matrix}\right] \]

          \(\Delta{D^T}\)是第\(T\)次迭代的補償矩陣,\(\gamma_j^T\)是第\(j\)個補償系數:

          \[\gamma_j^T=max(\frac{\sigma_j^1}{\sigma_1^1},\cdots,\frac{\sigma_j^t}{\sigma_1^t},\cdots,\frac{\sigma_j^T}{\sigma_1^T}),t=0,1,\cdots,T \]

          \(\sigma_j^t\)是第\(t\)次迭代的第\(j\)個奇異值。

      • 頻譜正則化的實現

        \[\Delta{W}=U\cdot{[\begin{matrix}\Delta{D}&0\\0&0\end{matrix}]}\cdot{V^T} \]

        \[\bar{W}_{SR}(W)=\frac{W+\Delta{W}}{\sigma(W)}=\bar{W}_{SN}(W)+\frac{\Delta{W}}{\sigma(W)} \]

  • implicit variational learning(隱式變分學習)

    • 文獻:VEEGAN: Reducing Mode Collapse in GANs using Implicit Variational Learning

    • 代碼:https://github.com/akashgit/VEEGAN/blob/master/VEEGAN_2D_RING.ipynb

    • 概念及定義:

      • 隱式變分原理(implicit variational principle):

        VEEGAN引入了一個額外的重構網絡(reconstructor network),將真實數據映射到高斯隨機噪聲,通過聯合訓練訓練生成器和重建器網絡鼓勵重建器網絡不僅將數據分布映射到高斯分布,而且還近似地反轉生成器的動作。

      • 如何理解使用隱式變分原理可以防止模式崩潰?

        • 觀察上圖:中部\(p(x)\)是由兩個高斯分布疊加而成的真實分布;底部\(p_0(z)\sim{N(0,1)}\)是生成器\(G_\gamma\)的輸入;頂部是將重構網\(F_\theta\)用於生成數據和真實數據的結果;由底部到中部的箭頭是生成器\(G_\gamma\)的動作;由中部到頂部的綠色箭頭是重構生成數據的動作,紫色箭頭是重構真實數據的動作。在圖中,生成器都只是捕獲了\(p(x)\)中其中一個高斯分布,圖a與圖b的區別在於重構網絡不同。
        • 圖a中\(F_\theta\)\(G_\gamma\)的逆,由於生成數據只包含真實數據的部分分布,\(F_\theta\)對真實數據中分布被丟失的那部分數據的處理結果不定,這也意味着其重構結果大概率與\(p_0(z)\)不匹配,這種不匹配可以作為模式崩潰的指標。
        • 圖b中\(F_\theta\)成功將真實數據重構回\(p_0(z)\),此時如果\(G_\gamma\)發生模式崩潰,\(F_\theta\)並不會將生成數據重構回\(p_0(z)\)(畢竟真實數據分布與生成數據分布存在差異),由此產生的懲罰信息提供了強大的\(G_\gamma\),\(F_\theta\)學習信息。
      • 文中提到了一個模式崩潰發生原因的猜想:目標函數提供的關於生成器網絡參數\(\gamma\)的唯一信息是由鑒別器網絡\(D_\omega\)介導的。(An intuition behind why mode collapse occurs is that the only information that the objective function provides about γ is mediated by the discriminator network Dω)

      • 重構網絡本質上是依據重構數據的差異反應生成數據和真實數據的差異,那為什么不直接度量生成數據和真實數據的分布差異呢?為什么必須要借助重構網絡呢?

    • 解決方案:

      • 重構損失

        \[\min_{\gamma,\theta}O_{entropy}(\gamma,\theta)=E[||z-F_{\theta}(G_\gamma(z))||_2^2]+H(Z,F_\theta(X))~~~~~~\tag{1} \]

        前半部分保證\(F_\theta\)\(G_\gamma\)的逆;后半部分保證對於真實數據,\(F_\theta\)的重構結果依舊是與\(p_0(z)\)相同的分布,使用交叉熵進行計算。

      • 為了便於計算,將重構損失進行如下轉換:

        重構網絡\(F_\theta(x)\)對應於分布\(p_{\theta}(z|x)\),樣本集合\(X\sim{p(x)}\)的平均重構數據為

        \[p_{\theta}(z)=\int{p_{\theta}(z|x)p(x)dx}~~~~~~\tag{2} \]

        根據交叉熵公式以及2式,\(H(Z,F_{\theta}(X))\)可寫作

        \[H(Z,F_{\theta}(X))=-\int{p_{0}(z)logp_{\theta}(z)dz}=-\int{p_0(z)}log\int{p(x)p_{\theta}(z|x)}dxdz~~~~~~\tag{3} \]

        \(p_\theta(z)=p_0(z)\)時交叉熵最小,為了使上式可計算(畢竟\(p(x)\)未知),引入變分分布\(q_\gamma(x|z)\)和Jensen不等式,有(推導看原文):

        \[-logp_\theta(z)=-log\int{p_\theta(z|x)p(x)\frac{q_\gamma(x|z)}{q_\gamma(x|z)}}dx\leq{\int{q_\gamma(x|z)log\frac{p_\theta(z|x)}{q_\gamma(x|z)}}}dx~~~~~~\tag{4} \]

        \[-\int{p_0(z)}logp_\theta(z)\leq{KL[q_\gamma(x|z)p_0(z)||p_\theta(z|x)p(x)]}-E[logp_0(z)]~~~~~~\tag{5} \]

        這里的\(q_\gamma(x|z)\)對應生成器,\(p_{\theta}(z|x)\)對應重構器,由1和5式可將優化目標轉化為(此為優化目標的上界):

        \[O_{entropy}(\gamma,\theta)=E[||z-F_{\theta}(G_\gamma(z))||_2^2]+KL[q_\gamma(x|z)p_0(z)||p_\theta(z|x)p(x)]-E[logp_0(z)]~~~~~~\tag{6} \]

        6式還是無法計算,因為\(q_\gamma(x|z)\)對應生成器,\(p_{\theta}(z|x)\)對應重構器,都是隱式表示,分布未知;樣本數據分布\(p(x)\)也是未知,這里假設訓練所得判別器\(D_\omega(x,z)\)滿足

        \[D_\omega(Z,X)=log\frac{q_\gamma(x|z)p_0(z)}{p_\theta(z|x)p(x)}\tag{7} \]

        並有

        \[\hat{O}(\omega,\gamma,\theta)=\frac{1}{N}\sum_{i=1}^{N}D_{\omega}(z^i,x^i_g)+\frac{1}{N}\sum_{i=1}^{N}d(z^i,x_g^i) \tag{8} \]

        其中\((z^i,x_g^i)\sim{p_0(x)q_\gamma(x|z)}\),優化目標最終化為:

        \[O_{LR}(\omega,\gamma,\theta)=-E_\gamma[log(\sigma(D_{\omega}(z,x)))]-E_{\theta}[log(1-\sigma({D_{\omega}(z,x)}))] \tag{9} \]

        訓練偽代碼:生成器\(\gamma\),重構器\(\theta\),判別器\(\omega\)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM