生成對抗網絡(GAN)


  GAN的全稱是 Generative Adversarial Networks,中文名稱是生成對抗網絡。原始的GAN是一種無監督學習方法,巧妙的利用“博弈”的思想來學習生成式模型。

1 GAN的原理

  GAN的基本原理很簡單,其由兩個網絡組成,一個是生成網絡G(Generator) ,另外一個是判別網絡D(Discriminator)。它們的功能分別是:

  生成網絡G:負責生成圖片,它接收一個隨機的噪聲 $z$,通過該噪聲生成圖片,將生成的圖片記為 $G(z)$。

  判別網絡D:負責判別一張圖片是真實的圖片還是由G生成的假的圖片。其輸入是一張圖片 $x$ ,輸出是0, 1值,0代表圖片是由G生成的,1代表是真實圖片。

  在訓練過程中,生成網路G的目標是盡量生成真實的圖片去欺騙判別網絡D。而判別網絡D的目標就是盡量把G生成的圖片和真實的圖片區分開來。這樣G和D就構成了一個動態的博弈過程。這是GAN的基本思想。

  在最理想的狀態下,G可以生成足以“以假亂真”的圖片 $G(z)$。對於D來說,它難以判斷G生成的圖片究竟是不是真實的,因此 $D(G(z)) = 0.5$ (在這里我們輸入的真實圖片和生成的圖片是各一半的)。此時得到的生成網絡G就可以用來生成圖片。

 

2 GAN損失函數

  從數學的角度上來看GAN,假設用於訓練的真實圖片數據是 $x$,圖片數據的分布為 $p_{data}(x)$,生成網絡G需要去學習到真實數據分布 $p_{data}(x)$。噪聲 $z$ 的分布假設為$p_z(z)$,在這里 $p_z(z)$是已知的,而 $p_{data}(x)$ 是未知的。在理想的狀態下$G(z)$ 的分布應該是盡可能接近$p_{data}(x)$,G將已知分布的$z$ 變量映射到位置分布 $x$ 變量上。

  根據交叉熵損失,可以構造下面的損失函數:

  $ V(D,G) = E_{x~p_{data}(x)} [ln D(x)] + E_{z~p_z(z)} [ln(1-D(G(z)))] $

  其實從損失函數中可以看出和邏輯回歸的損失函數基本一樣,唯一不一樣的是負例的概率值為 $ 1-D(G(z))$。

  損失函數中加號的前一半是訓練數據中的真實樣本,后一半是從已知的噪聲分布中取的樣本。下面對這個損失函數詳細描述:

  1)整個式子有兩項構成。 $x$表示真實圖片,$z$表示輸入G網絡的噪聲,而$G(z)$ 表示G網絡生成的圖片。

  2)$D(x)$ 表示D網絡判斷真實圖片是否真實的概率 ,即 $P(y=1 | x)$。而$D(G(z))$ 是D網絡判斷$G$生成的圖片是否真實的概率。

  3)G的目的:G應該希望自己生成的圖片越真實越好。也就是說G希望 $D(G(z))$ 盡可能大,即$P(G(z) = 1 | x)$,這時 $V(D, G)$ 盡可能小。

  4)D的目的:D的能力越強,$D(x)$ 就應該越大,$D(G(x))$應該越小(即假的圖片都被識別為0)。因此D的目的和G的目的不同,D希望 $V(D, G)$ 越大越好。

 

3 GAN建模流程

  在實際訓練中,使用梯度下降法,對D和G交替做優化,具體步驟如下:

  1)從已知的噪聲分布 $p_z(z)$中選取一些樣本

    ${z_1, z_2, ......, z_m}$

  2)從訓練數據中選出同樣個數的真實圖片

    ${x_1, x_2, ......, x_m}$

  3)設判別器D的參數為 $\theta_d$,其損失函數的梯度為

    $ \nabla \frac{1}{m} \sum_{i=1}^m [lnD(x_i) + ln(1-D(G(Z_I)))] $

  4)設生成器G的參數為 $\theta_g$,其損失函數的梯度為

    $ \nabla \frac{1}{m} \sum_{i=1}^m [ln(1-D(G(Z_I)))] $

  在上面的步驟中,每更新一次D的參數,緊接着就更新一次G的參數,有時也可以在更新 $k$ 次D的參數,再更新一次G的參數。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM