生成性對抗神經網絡


GAN是什么?

生成對抗網絡(GANs)是當今計算機科學中最有趣的概念之一。
兩個模型通過對抗性過程同時訓練。
生成器(“藝術家”)學會創建看起來真實的圖像,而鑒別器(“藝術評論家”)學會區分真實圖像和贗品。


 
 

在訓練過程中,生成器逐漸變得更擅長創建看起來真實的圖像,而鑒別器則變得更擅長區分它們。
當鑒別器無法分辨真偽圖像時,該過程達到平衡。


 
 

下面的動畫展示了生成器在經過50個時代的訓練后生成的一系列圖像。
這些圖像一開始是隨機噪聲,隨着時間的推移越來越像手寫數字。

 
 

一、介紹

1.1 原理

這是一張關於GAN的流程圖


 
GAN

GAN主要的靈感來源是零和游戲在博弈論思想,應用於深學習神經網絡,是通過生成網絡G(發電機)和判別D(鑒頻器)網絡游戲不斷,從而使G學習數據分布,如果用在圖像生成訓練完成后,G可以從一個隨機數生成逼真的圖像。
G和D的主要功能是:

  • G是一個生成網絡,它接收一個隨機噪聲z(隨機數),通過噪聲生成圖像。

  • D是一個判斷圖像是否“真實”的網絡。它的輸入參數是x, x代表一張圖片,輸出D (x)代表x是一張真實圖片的概率。如果是1,代表100%真實的圖像,如果是0,代表不可能的圖像。

在訓練過程中,生成網絡G的目標是生成盡可能多的真實圖像來欺騙網絡D,而D的目標是試圖將G生成的假圖像與真實圖像區分開來。這樣,G和D構成一個動態的“博弈過程”,最終的均衡點為納什均衡點。

1.2 體系結構

通過對目標的優化,可以調整概率生成模型的參數,使概率分布與實際數據分布盡可能接近。

那么,如何定義適當的優化目標或損失呢?
在傳統的生成模型中,一般采用數據的似然作為優化目標,而GAN創新性地使用了另一個優化目標。

  • 首先,引入判別模型(常用模型包括支持向量機和多層神經網絡)。

  • 其次,其優化過程是在生成模型和判別模型之間找到納什均衡。

GAN建立的學習框架實際上是生成模型和判別模型之間的模擬博弈。
生成模型的目的是盡可能多地模擬、建模和學習真實數據的分布規律。
判別模型是判斷一個輸入數據是來自真實的數據分布還是生成的模型。
通過這兩個內部模型之間的持續競爭,提高了生成和區分這兩個模型的能力。

當一個模型具有很強的區分能力時。
如果生成的模型數據仍然存在混淆,不能正確判斷,那么我們認為生成的模型實際上已經了解了真實數據的分布情況。

1.3 GAN特性

特點:

  • low與傳統模式相比,有兩種不同的網絡,而不是單一的網絡,采用的是對抗訓練方法和訓練方式。

  • 更新信息中的低GAN梯度G來自判別式D,而不是來自樣本數據。

優勢:

  • low GAN是一個涌現模型,相對於其他生成模型(玻爾茲曼機和GSNs),它只通過反向傳播,不需要復雜的馬爾可夫鏈。

  • 與其它所有機型相比,GAN能生產出更清晰、真實的樣品

  • low GAN是一種無監督學習訓練,可廣泛應用於半監督學習和無監督學習領域。

  • 與變分自編碼器相比,GANs不引入任何確定性偏差,變分方法引入確定性偏差,因為它們優化了對數似然的下界而不是似然本身,這似乎導致VAEs生成的實例比GANs更加模糊。

  • 與VAE、GANs的變分下界相比較低,如果判別器訓練良好,則生成器可以學習完善訓練樣本分布。換句話說,GANs是逐漸一致的,但是VAE是有偏見的。

  • GAN——應用於一些場景,比如圖片風格轉換、超分辨率,圖像完成、噪聲去除,避免了損失函數設計的困難,只要有一個基准,直接鑒別器,其余的對抗訓練。

缺點:

  • 訓練GAN需要達到Nash均衡,有時可以通過梯度下降法實現,有時則不能。我們還沒有找到一個很好的方法來達到納什均衡,所以與VAE或PixelRNN相比,GAN的訓練是不穩定的,但我認為在實踐中它比訓練玻爾茲曼機更穩定。

  • GAN不適用於處理離散數據,如文本。

  • GAN存在訓練不穩定、梯度消失和模態崩潰等問題。

二、實現

加載和准備數據集,將使用MNIST數據集來訓練生成器和鑒別器。生成器將生成類似MNIST數據的手寫數字。

import tensorflow as tf def load_dataset(mnist_size, mnist_batch_size, cifar_size, cifar_batch_size,): """ load mnist and cifar10 dataset to shuffle. Args: mnist_size: mnist dataset size. mnist_batch_size: every train dataset of mnist. cifar_size: cifar10 dataset size. cifar_batch_size: every train dataset of cifar10. Returns: mnist dataset, cifar10 dataset """ # load mnist data (mnist_train_images, mnist_train_labels), (_, _) = tf.keras.datasets.mnist.load_data() # load cifar10 data (cifar_train_images, cifar_train_labels), (_, _) = tf.keras.datasets.cifar10.load_data() mnist_train_images = mnist_train_images.reshape(mnist_train_images.shape[0], 28, 28, 1).astype('float32') mnist_train_images = (mnist_train_images - 127.5) / 127.5 # Normalize the images to [-1, 1] cifar_train_images = cifar_train_images.reshape(cifar_train_images.shape[0], 32, 32, 3).astype('float32') cifar_train_images = (cifar_train_images - 127.5) / 127.5 # Normalize the images to [-1, 1] # Batch and shuffle the data mnist_train_dataset = tf.data.Dataset.from_tensor_slices(mnist_train_images) mnist_train_dataset = mnist_train_dataset.shuffle(mnist_size).batch(mnist_batch_size) cifar_train_dataset = tf.data.Dataset.from_tensor_slices(cifar_train_images) cifar_train_dataset = cifar_train_dataset.shuffle(cifar_size).batch(cifar_batch_size) return mnist_train_dataset, cifar_train_dataset 

2.2 創造模型文件

生成器和鑒別器都使用Keras Sequential API.

2.2.1 生成器模型

這里只對神經網絡體系結構使用最基本的全連接形式。
除第一層不使用歸一化外,其余層均由全連接的線性結構定義——>歸一化——>LeakReLU,具體參數如下所示。

import tensorflow as tf from tensorflow.python.keras import layers def make_generator_model(dataset='mnist'): """ implements generate. Args: dataset: mnist or cifar10 dataset. (default='mnist'). choice{'mnist', 'cifar'}. Returns: model. """ model = tf.keras.models.Sequential() model.add(layers.Dense(256, input_dim=100)) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Dense(512)) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Dense(1024)) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU(alpha=0.2)) if dataset == 'mnist': model.add(layers.Dense(28 * 28 * 1, activation='tanh')) model.add(layers.Reshape((28, 28, 1))) elif dataset == 'cifar': model.add(layers.Dense(32 * 32 * 3, activation='tanh')) model.add(layers.Reshape((32, 32, 3))) return model 

2.2.2 鑒別器模型

鑒別器是一種基於cnn的圖像分類器。

import tensorflow as tf from tensorflow.python.keras import layers def make_discriminator_model(dataset='mnist'): """ implements discriminate. Args: dataset: mnist or cifar10 dataset. (default='mnist'). choice{'mnist', 'cifar'}. Returns: model. """ model = tf.keras.models.Sequential() if dataset == 'mnist': model.add(layers.Flatten(input_shape=[28, 28, 1])) elif dataset == 'cifar': model.add(layers.Flatten(input_shape=[32, 32, 3])) model.add(layers.Dense(1024)) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Dense(512)) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Dense(256)) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Dense(1, activation='sigmoid')) return model 

2.3 定義損失和優化器

2.3.1 為這兩個模型定義損失函數。

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) 

2.3.2 鑒別器損失函數

該方法量化了鑒別器對真偽圖像的識別能力。
它將鑒別器對真實圖像的預測與1的數組進行比較,將鑒別器對假(生成的)圖像的預測與0的數組進行比較。

def discriminator_loss(real_output, fake_output): """ This method quantifies how well the discriminator is able to distinguish real images from fakes. It compares the discriminator's predictions on real images to an array of 1s, and the discriminator's predictions on fake (generated) images to an array of 0s. Args: real_output: origin pic. fake_output: generate pic. Returns: real loss + fake loss """ real_loss = cross_entropy(tf.ones_like(real_output), real_output) fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) total_loss = real_loss + fake_loss return total_loss 

2.3.3 生成器損失函數

發電機的損耗量化了它欺騙鑒別器的能力。
直觀地說,如果生成器運行良好,鑒別器將把假圖像分類為真實圖像(或1)。
在這里,我們將把鑒別器對生成圖像的判斷與1的數組進行比較。

def generator_loss(fake_output): """ The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1). Here, we will compare the discriminators decisions on the generated images to an array of 1s. Args: fake_output: generate pic. Returns: loss """ return cross_entropy(tf.ones_like(fake_output), fake_output) 

2.3.4 優化

由於我們將分別訓練兩個網絡,因此鑒別器和生成器優化器是不同的。

def generator_optimizer(): """ The training generator optimizes the network. Returns: optim loss. """ return tf.keras.optimizers.Adam(lr=1e-4) def discriminator_optimizer(): """ The training discriminator optimizes the network. Returns: optim loss. """ return tf.keras.optimizers.Adam(lr=1e-4) 

2.4 保存訓練模型

本筆記本還演示了如何保存和恢復模型,這在長時間運行的訓練任務被中斷時是很有幫助的。

import os import tensorflow as tf def save_checkpoints(generator, discriminator, generator_optimizer, discriminator_optimizer, save_path): """ save gan model Args: generator: generate model. discriminator: discriminate model. generator_optimizer: generate optimizer func. discriminator_optimizer: discriminator optimizer func. save_path: save gan model dir path. Returns: checkpoint path """ checkpoint_dir = save_path checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, discriminator_optimizer=discriminator_optimizer, generator=generator, discriminator=discriminator) return checkpoint_dir, checkpoint, checkpoint_prefix 

2.5 訓練

2.5.1 訓練設置

訓練循環從生成器接收隨機種子作為輸入開始。
種子是用來產生圖像的。
然后使用鑒別器對真實圖像(來自訓練集)和偽造圖像(由生成器生成)進行分類。
計算了每一種模型的損耗,並利用梯度對產生器和鑒別器進行了更新。

from dataset.load_dataset import load_dataset from network.generator import make_generator_model from network.discriminator import make_discriminator_model from util.loss_and_optim import generator_loss, generator_optimizer from util.loss_and_optim import discriminator_loss, discriminator_optimizer from util.save_checkpoints import save_checkpoints from util.generate_and_save_images import generate_and_save_images import tensorflow as tf import time import os import argparse parser = argparse.ArgumentParser() parser.add_argument('--dataset', default='mnist', type=str, help='use dataset {mnist or cifar}.') parser.add_argument('--epochs', default=50, type=int, help='Epochs for training.') args = parser.parse_args() print(args) # define model save path save_path = 'training_checkpoint' # create dir if not os.path.exists(save_path): os.makedirs(save_path) # define random noise noise = tf.random.normal([16, 100]) # load dataset mnist_train_dataset, cifar_train_dataset = load_dataset(60000, 128, 50000, 64) # load network and optim paras generator = make_generator_model(args.dataset) generator_optimizer = generator_optimizer() discriminator = make_discriminator_model(args.dataset) discriminator_optimizer = discriminator_optimizer() checkpoint_dir, checkpoint, checkpoint_prefix = save_checkpoints(generator, discriminator, generator_optimizer, discriminator_optimizer, save_path) # This annotation causes the function to be "compiled". @tf.function def train_step(images): """ break it down into training steps. Args: images: input images. """ with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: generated_images = generator(noise, training=True) real_output = discriminator(images, training=True) fake_output = discriminator(generated_images, training=True) gen_loss = generator_loss(fake_output) disc_loss = discriminator_loss(real_output, fake_output) gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) generator_optimizer.apply_gradients( zip(gradients_of_generator, generator.trainable_variables)) discriminator_optimizer.apply_gradients( zip(gradients_of_discriminator, discriminator.trainable_variables)) def train(dataset, epochs): """ train op Args: dataset: mnist dataset or cifar10 dataset. epochs: number of iterative training. """ for epoch in range(epochs): start = time.time() for image_batch in dataset: train_step(image_batch) # Produce images for the GIF as we go generate_and_save_images(generator, epoch + 1, noise, save_path) # Save the model every 15 epochs if (epoch + 1) % 15 == 0: checkpoint.save(file_prefix=checkpoint_prefix) print(f'Time for epoch {epoch+1} is {time.time()-start:.3f} sec.') # Generate after the final epoch generate_and_save_images(generator, epochs, noise, save_path) if __name__ == '__main__': if args.dataset == 'mnist': train(mnist_train_dataset, args.epochs) else: train(cifar_train_dataset, args.epochs) 

2.6 生成圖片並保存

from matplotlib import pyplot as plt def generate_and_save_images(model, epoch, test_input): # Notice `training` is set to False. # This is so all layers run in inference mode (batchnorm). predictions = model(test_input, training=False) fig = plt.figure(figsize=(4,4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i+1) plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') plt.axis('off') plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) plt.show() plt.close(fig) 

三、常見問題

3.1 為什么GAN中的優化器不經常使用SGD

  • SGD易動搖,易使GAN訓練不穩定。

  • GAN的目的是在高維非凸參數空間中尋找納什均衡點。

GAN的納什均衡點是鞍點,而SGD只會找到局部最小值,因為SGD解決了尋找最小值的問題,而GAN是一個博弈問題。

3.2為什么GAN不適合處理文本數據

文本數據是離散的圖像數據相比,因為文本,通常需要地圖一個詞作為一個高維向量,最后預測輸出是一個熱向量,假設softmax輸出(0.2,0.3,0.1,0.2,0.15,0.05)就變成了onehot, 1, 0, 0, 0, 0(0),如果將softmax輸出(0.2,0.25,0.2,0.1,0.15,0.1),一個仍然是熱(0,1,0,0,0,0)。
因此,對於生成器,G輸出不同的結果,而D給出相同的判別結果,不能很好地將梯度更新信息傳遞給G,因此D最終輸出的判別是沒有意義的。

  • 此外,GAN的損失函數為JS散度,不適合測量不想相交的分布之間的距離。

3.3 GAN的一些技能培訓

  • 使用tanh begin將輸入規范化為(- 1,1),最后一級激活函數(異常)

  • 使用wassertein GAN的loss函數,

  • 如果你有標簽數據,嘗試使用標簽。有人建議使用倒裝標簽,並使用標簽平滑,單側標簽平滑或雙側標簽平滑

  • 使用小型批處理范數,如果不使用批處理范數,可以使用實例范數或權重范數

  • 避免使用RELU和池化層來降低稀疏梯度的可能性,可以使用leakrelu激活函數

  • 優化器盡可能選擇ADAM,學習速度不應該太大。初始的1e-4可以參考。此外,隨着培訓的進行,學習率可以不斷降低。

  • 在D的網絡層中加入高斯噪聲相當於一種正則化

3.4 模型崩潰原因

一般來說,GAN在訓練中並不穩定,效果也很差。
然而,即使延長培訓時間,也不能很好地改善。

具體原因可以解釋如下:
是針對甘用的訓練方法,G梯度更新自D,生成的G很好,所以D對我說什么。
具體來說,G將生成一個樣本,並將其交給D進行評估。
D將輸出生成的假樣本為真樣本的概率(0-1),這相當於告訴G生成的樣本有多真實。
G會根據這個反饋改進自身,提高D輸出的概率值。
但如果G生成樣本可能不是真的,但是D給予正確的評價,或者是G的結果生成的一些特征的識別D,然后G輸出會認為我是對的,所以我所以輸出D肯定也會給一個高評價,G實際上不是生成的,但他們是兩個自我欺騙,導致最終的結果缺少一些信息,特征。

四、GAN在生活中的應用

  • GAN本身就是一個生成模型,所以數據生成是最常見的,最常見的是圖像生成,常用的DCGAN WGAN開始,個人感覺在開始時最好也最簡單。

  • GAN本身也是一個無監督學習的模型。因此在無監督學習和半監督學習中得到了廣泛的應用。

  • GAN不僅在生成領域發揮作用,還在分類領域發揮作用。簡而言之,它是將識別器替換為分類器,執行多個分類任務,而生成器仍然執行生成任務並輔助分類器訓練。

  • GAN可以與強化學習相結合。seq-gan就是一個很好的例子。

  • 目前,GAN在圖像樣式轉換、圖像降噪和恢復、圖像超分辨率等方面都有很好的應用前景。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM