理解GAN對抗神經網絡的損失函數和訓練過程


GAN最不好理解的就是Loss函數的定義和訓練過程,這里用一段代碼來輔助理解,就能明白到底是怎么回事。其實GAN的損失函數並沒有特殊之處,就是常用的binary_crossentropy,關鍵在於訓練過程中存在兩個神經網絡和兩個損失函數。

np.random.seed(42)
tf.random.set_seed(42)

codings_size = 30

generator = keras.models.Sequential([
    keras.layers.Dense(100, activation="selu", input_shape=[codings_size]),
    keras.layers.Dense(150, activation="selu"),
    keras.layers.Dense(28 * 28, activation="sigmoid"),
    keras.layers.Reshape([28, 28])
])
discriminator = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28]),
    keras.layers.Dense(150, activation="selu"),
    keras.layers.Dense(100, activation="selu"),
    keras.layers.Dense(1, activation="sigmoid")
])
gan = keras.models.Sequential([generator, discriminator])

discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")

batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

這里generator並不用compile,因為gan網絡已經compile了。具體原因見下文。

訓練過程的代碼如下

def train_gan(gan, dataset, batch_size, codings_size, n_epochs=50):
    generator, discriminator = gan.layers
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))              # not shown in the book
        for X_batch in dataset:
            # phase 1 - training the discriminator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)
            # phase 2 - training the generator
            noise = tf.random.normal(shape=[batch_size, codings_size])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)
        plot_multiple_images(generated_images, 8)                     # not shown
        plt.show()                                                    # not shown

第一階段(discriminator訓練)

# phase 1 - training the discriminator
noise = tf.random.normal(shape=[batch_size, codings_size])
generated_images = generator(noise)
X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
discriminator.trainable = True
discriminator.train_on_batch(X_fake_and_real, y1)

這個階段首先生成數量相同的真實圖片和假圖片,concat在一起,即X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)。然后是label,真圖片的label是1,假圖片的label是0。

然后是迅速階段,首先將discrinimator設置為可訓練,discriminator.trainable = True,然后開始階段。第一個階段的訓練過程只訓練discriminator,discriminator.train_on_batch(X_fake_and_real, y1),而不是整個GAN網絡gan

第二階段(generator訓練)

# phase 2 - training the generator
noise = tf.random.normal(shape=[batch_size, codings_size])
y2 = tf.constant([[1.]] * batch_size)
discriminator.trainable = False
gan.train_on_batch(noise, y2)

在第二階段首先生成假圖片,但是不再生成真圖片。把假圖片的label全部設置為1,並把discriminator的權重凍結,即discriminator.trainable = False。這一步很關鍵,應該這么理解:

前面第一階段的是discriminator的訓練,使真圖片的預測值盡量接近1,假圖片的預測值盡量接近0,以此來達到優化損失函數的目的。現在將discrinimator的權重凍結,網絡中輸入假圖片,並故意把label設置為1。

注意,在整個gan網絡中,從上向下的順序是先通過geneartor,再通過discriminator,即gan = keras.models.Sequential([generator, discriminator])。第二個階段將discrinimator凍結,並訓練網絡gan.train_on_batch(noise, y2)。如果generator生成的圖片足夠真實,經過discrinimator后label會盡可能接近1。由於故意把y2的label設置為1,所以如果genrator生成的圖片足夠真實,此時generator訓練已經達到最優狀態,不會大幅度更新權重;如果genrator生成的圖片不夠真實,經過discriminator之后,預測值會接近0,由於y2的label是1,相當於預測值不准確,這時候gan網絡的損失函數較大,generator會通過更新generator的權重來降低損失函數。

之后,重新回到第一階段訓練discriminator,然后第二階段訓練generator。假設整個GAN網絡達到理想狀態,這時候generator產生的假圖片,經過discriminator之后,預測值應該是0.5。假如這個值小於0.5,證明generator不是特別准確,在第二階段訓練過程中,generator的權重會被繼續更新。假如這個值大於0.5,證明discriminator不是特別准確,在第一階段訓練中,discriminator的權征會被繼續更新。

簡單說,對於一張generator生成的假圖片,discriminator會盡量把預測值拉下拉,generator會盡量把預測值往上扯,類似一個拔河的過程,最后達到均衡狀態,例如0.6, 0.4, 0.55, 0.45, 0.51, 0.49, 0.50。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM