An Illustrated Guide to WGAN and Its Variant WGAN-GP, with TensorFlow 2 Implementations of Both


Building a WGAN (Wasserstein GAN)

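Since GANs were first proposed, many papers have tried to tame their training instability with heuristics such as different network architectures, hyperparameters, and optimizers. The introduction of the Wasserstein GAN (WGAN) was a major breakthrough on this problem.
WGAN alleviates, and in some cases eliminates, many of the problems seen in GAN training. Its fundamental change relative to the original GAN is the loss function. In theory, when the real and generated distributions do not overlap, the JS divergence saturates at a constant value, so it gives the generator no useful gradient (the gradient is zero). WGAN addresses this with a new loss function that is continuous and differentiable almost everywhere.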
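The contrast can be checked on a toy case. The following is a minimal sketch, not part of the original article, that compares the two measures for two point masses on a 1-D grid; the helper names `js_divergence`, `wasserstein_1d`, and `point_mass` are made up for this illustration. The JS divergence stays at log 2 however far apart the masses are, while the Wasserstein distance grows with the gap.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions."""
    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1d(p, q):
    """Wasserstein-1 distance on a unit-spaced grid: sum of |CDF differences|."""
    total, cdf_gap = 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_gap += pi - qi
        total += abs(cdf_gap)
    return total

def point_mass(position, size=10):
    """Distribution that puts all probability on one grid cell."""
    return [1.0 if i == position else 0.0 for i in range(size)]

real = point_mass(0)
for shift in (1, 3, 6):
    fake = point_mass(shift)
    # JS is identical for every shift; the Wasserstein distance equals the shift
    print(shift, round(js_divergence(real, fake), 4), wasserstein_1d(real, fake))
```

Because the JS value does not change as the fake mass moves, it carries no information about which direction to move in, whereas the Wasserstein distance shrinks steadily as the two masses approach each other.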

Introducing the Wasserstein loss

The objective function of the original GAN is well known, so we only recap it briefly:
$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1-D(G(z)))]$$
where $D$ is the discriminator, $G$ is the generator, $x$ is real data, and $z$ is the latent variable.
Replacing the generator's term $\log(1-D(G(z)))$ with its non-saturating counterpart $\log D(G(z))$ gives the following value function:
$$\mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log D(G(z))]$$
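WGAN uses a new loss function based on the Earth Mover's distance, also called the Wasserstein distance. It measures the distance, or amount of work, needed to transform one distribution into another. Mathematically, it is the minimum expected distance between paired samples, taken over every joint distribution of real and generated images. The WGAN value function becomes:
$$\mathbb{E}_{x\sim p_{data}(x)}[D(x)] - \mathbb{E}_{z\sim p_z(z)}[D(G(z))]$$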
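For reference, the link between the "minimum over joint distributions" description and the value function above can be written out as follows. This is a sketch of the standard statement (the Kantorovich-Rubinstein duality), not a derivation from the original article; the symbols $p_g$ (the distribution of $G(z)$), $\gamma$, and $\Pi(p_{data}, p_g)$ (the set of joint distributions with those two marginals) are introduced here for illustration.

```latex
W(p_{data}, p_g)
  = \inf_{\gamma \in \Pi(p_{data},\, p_g)} \mathbb{E}_{(x, y) \sim \gamma}\big[\lVert x - y \rVert\big]
  = \sup_{\lVert D \rVert_L \le 1}
      \Big( \mathbb{E}_{x \sim p_{data}(x)}[D(x)] - \mathbb{E}_{z \sim p_z(z)}[D(G(z))] \Big)
```

The supremum runs over 1-Lipschitz functions $D$, which is why the critic has to satisfy a Lipschitz constraint for the value function to approximate the Wasserstein distance.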
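We derive the loss function from the WGAN value function. Because Keras minimizes losses while the critic wants to maximize the first term, we negate it and write it as:
$$-\frac{1}{N}\sum_{i=1}^{N} y_i D(x_i)$$
This is the mean critic output multiplied by -1, with $y_i$ as the label: +1 for real images and -1 for fake images. With those labels the same expression also covers the second term, so the Wasserstein loss can be implemented as a custom TensorFlow Keras loss function: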

def wasserstein_loss(self, y_true, y_pred):
    # y_true is +1 for real images and -1 for fake images
    w_loss = -tf.reduce_mean(y_true * y_pred)
    return w_loss

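The loss aims to maximize the score of real images relative to fake images. Because the network now produces a score rather than a probability, the discriminator in WGAN is also called the critic.
However, WGAN removes the sigmoid activation from the discriminator's output, so the critic's predictions are unbounded and must be constrained by a 1-Lipschitz condition.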
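As a quick numeric check of how the +1/-1 labels steer the loss, here is a small pure-Python mirror of the same formula. The critic scores are hypothetical values chosen for illustration, and this standalone `wasserstein_loss` function is a plain-Python stand-in for the Keras version, not the article's implementation.

```python
def wasserstein_loss(y_true, y_pred):
    """Plain-Python mirror of -mean(y_true * y_pred)."""
    products = [t * p for t, p in zip(y_true, y_pred)]
    return -sum(products) / len(products)

# Hypothetical critic scores for two real and two fake images
real_scores = [2.0, 1.0]
fake_scores = [-1.0, 0.5]

loss_real = wasserstein_loss([1.0, 1.0], real_scores)       # -mean(D(x))
loss_fake = wasserstein_loss([-1.0, -1.0], fake_scores)     # +mean(D(G(z)))
critic_loss = loss_real + loss_fake                         # critic minimizes this
generator_loss = wasserstein_loss([1.0, 1.0], fake_scores)  # -mean(D(G(z)))

print(loss_real, loss_fake, critic_loss, generator_loss)
```

Minimizing `critic_loss` raises the scores of real images and lowers those of fake images, while the generator, trained with +1 labels on its own samples, tries to raise the scores of the fakes.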

Implementing the 1-Lipschitz constraint

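The mathematical assumption behind the Wasserstein loss is that the critic is a 1-Lipschitz function. A critic $D(x)$ is 1-Lipschitz if it satisfies the following inequality:
$$|D(x_1)-D(x_2)| \leq |x_1-x_2|$$
For two images $x_1$ and $x_2$, the absolute difference between the critic's outputs must not exceed the distance between the images, measured through their pixel-wise difference. In other words, the critic's output must not change faster than its input, whether the images are real or fake. When WGAN was proposed, its authors had no principled way to enforce this inequality, so they resorted to a workaround: clipping the critic's weights to very small values. This keeps each layer's output, and ultimately the critic's output, within a small range. In the WGAN paper, the weights are clipped to the range [-0.01, 0.01].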
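To make the inequality concrete, the sketch below estimates the steepest observed slope of a toy linear critic before and after clipping. The linear critic, the input size `DIM`, and the helper names are assumptions made for this illustration; a real convolutional critic is not linear, so this only shows the mechanism, not a guarantee.

```python
import math
import random

random.seed(0)
DIM = 16  # hypothetical flattened image size

def clip(values, low=-0.01, high=0.01):
    """Clamp every weight into [low, high], as weight clipping does."""
    return [min(max(v, low), high) for v in values]

def critic(weights, x):
    """Toy linear critic D(x) = w . x."""
    return sum(w * xi for w, xi in zip(weights, x))

def max_slope(weights, trials=1000):
    """Largest observed |D(x1) - D(x2)| / ||x1 - x2||_2 over random pairs."""
    worst = 0.0
    for _ in range(trials):
        x1 = [random.uniform(-1, 1) for _ in range(DIM)]
        x2 = [random.uniform(-1, 1) for _ in range(DIM)]
        dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(x1, x2)))
        diff = abs(critic(weights, x1) - critic(weights, x2))
        worst = max(worst, diff / dist)
    return worst

raw_weights = [random.gauss(0, 1) for _ in range(DIM)]
clipped_weights = clip(raw_weights)

print("unclipped slope:", max_slope(raw_weights))
print("clipped slope:  ", max_slope(clipped_weights))
```

For a linear critic the slope can never exceed the L2 norm of the weight vector, which after clipping is at most 0.01 * sqrt(16) = 0.04, comfortably below the limit of 1.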
Weight clipping can be implemented in two ways. One is to write a custom constraint function and use it when instantiating a new layer, as follows:

class WeightsClip(tf.keras.constraints.Constraint):
    def __init__(self, min_value=-0.01, max_value=0.01):
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, w):
        # Clamp every weight into [min_value, max_value]
        return tf.clip_by_value(w, self.min_value, self.max_value)

The constraint can then be passed to any layer that accepts constraint functions, as follows:

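from tensorflow.keras.layers import Conv2D, BatchNormalization

model = tf.keras.Sequential(name='critics')
# Clip both the kernel and the bias of the convolution
model.add(Conv2D(16, 3, strides=2, padding='same',
                 kernel_constraint=WeightsClip(),
                 bias_constraint=WeightsClip()))
# Clip the learnable offset (beta) and scale (gamma) of batch normalization
model.add(BatchNormalization(beta_constraint=WeightsClip(),
                             gamma_constraint=WeightsClip()))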

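However, adding constraint arguments to every layer bloats the code. Since we do not need to pick which layers to clip, we can instead loop over the layers, read their weights, clip them, and write them back, as follows: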

for layer in critic.layers:
    weights = layer.get_weights()
    weights = [tf.clip_by_value(w, -0.01, 0.01) for w in weights]
    layer.set_weights(weights)

Training procedure

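In the original GAN theory, the discriminator should be trained ahead of the generator. In practice, however, the discriminator learns faster, and once it becomes too strong its gradients gradually vanish, leaving the generator nothing to learn from. With the Wasserstein loss, useful gradients are available everywhere, so we no longer have to worry about the critic overpowering the generator.
Therefore, in WGAN the critic is trained five times for every generator training step. To do this, we write the critic training step as a separate function that can be called in a loop: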

for _ in range(self.n_critic):
    real_images = next(data_generator)
    critic_loss = self.train_critic(real_images, batch_size)
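The resulting schedule can be traced without any model at all. The following sketch uses a hypothetical `batches` generator and a logging `train` function (both made up for illustration) to show that each generator step is preceded by five critic updates, each consuming a fresh batch.

```python
def batches():
    """Hypothetical endless data stream that yields a batch index."""
    index = 0
    while True:
        yield index
        index += 1

def train(data_generator, steps, n_critic=5):
    """Record the order of updates instead of performing them."""
    log = []
    for step in range(steps):
        for _ in range(n_critic):
            batch = next(data_generator)  # a fresh real batch per critic update
            log.append(("critic", step, batch))
        log.append(("generator", step, None))
    return log

log = train(batches(), steps=2)
critic_updates = [entry for entry in log if entry[0] == "critic"]
generator_updates = [entry for entry in log if entry[0] == "generator"]
print(len(critic_updates), len(generator_updates))
```

One consequence worth noting is that a run of N generator steps consumes 5N batches of real data, which matters when sizing the dataset pipeline.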

The model used for the generator's training step is built as follows:

self.critic = self.build_critic()
self.critic.trainable = False
self.generator = self.build_generator()
critic_output = self.critic(self.generator.output)
self.model = Model(self.generator.input, critic_output)
self.model.compile(loss = self.wasserstein_loss, optimizer = RMSprop(3e-4))
self.critic.trainable = True

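In the preceding code, we freeze the critic's layers by setting trainable = False, chain the critic to the generator to create a new model, and compile it. Afterwards we can set the critic back to trainable, which does not affect the model we have already compiled.
We use the train_on_batch() API to perform a single training step, which automatically runs the forward pass, loss calculation, backpropagation, and weight update: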

g_loss = self.model.train_on_batch(g_input, real_labels)

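The figure below shows the WGAN generator architecture:
[Figure: generator architecture]

The figure below shows the WGAN critic architecture:
[Figure: critic architecture]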

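Despite its improvements over the original GAN, WGAN is hard to train, and the images it produces are no better than those of the original GAN. Next, we implement WGAN-GP, a variant of WGAN that trains faster and produces sharper images.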

Implementing the gradient penalty (WGAN-GP)

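As the WGAN authors acknowledge, weight clipping is not an ideal way to enforce the Lipschitz constraint. It has two drawbacks: underuse of network capacity and exploding or vanishing gradients. When we clip the weights, we also limit the critic's ability to learn. Weight clipping forces the network to learn only simple features, so the network's capacity is underused. Second, the clipping value needs careful tuning. If it is set too high, the gradients explode and violate the Lipschitz constraint. If it is set too low, the gradients vanish as they are backpropagated through the network. In addition, weight clipping pushes the weights toward the two clipping limits, as shown in the figure below:
[Figure: comparison of WGAN and WGAN-GP]
Therefore, the gradient penalty (GP) was proposed to replace weight clipping as the way to enforce the Lipschitz constraint, as follows:
$$\text{Gradient penalty} = \lambda\,\mathbb{E}_{\hat x}\left[(\lVert \nabla_{\hat x}D(\hat x) \rVert_2-1)^2\right]$$
We will go through each variable in the equation and implement it in code.
We usually use $x$ to denote a real image, but the equation now contains $\hat x$. $\hat x$ is a pointwise interpolation between a real image and a fake image. The mixing ratio (epsilon) is drawn from a uniform distribution over [0, 1]: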

epsilon = tf.random.uniform((batch_size,1,1,1))
interpolates = epsilon*real_images + (1-epsilon)*fake_images

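According to the WGAN-GP paper, for our purposes we can read this as follows: because the gradient is taken on a mixture of real and fake images, we do not need to compute the penalty separately for real and fake images.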
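The role of the `(batch_size, 1, 1, 1)` shape is that each sample gets its own epsilon, which is then shared by every pixel of that sample. A minimal pure-Python sketch of the same idea, assuming flattened toy images and a hypothetical helper named `interpolate`:

```python
import random

random.seed(1)

def interpolate(real_batch, fake_batch):
    """Mix each real/fake pair with one epsilon shared by all of its pixels."""
    mixed, ratios = [], []
    for real, fake in zip(real_batch, fake_batch):
        eps = random.random()  # uniform on [0, 1)
        ratios.append(eps)
        mixed.append([eps * r + (1 - eps) * f for r, f in zip(real, fake)])
    return mixed, ratios

# Toy "images" with three pixels each, scaled to [-1, 1] like the real data
real_batch = [[1.0, 1.0, 1.0], [0.5, 0.5, 0.5]]
fake_batch = [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0]]

mixed, ratios = interpolate(real_batch, fake_batch)
for eps, sample in zip(ratios, mixed):
    print(round(eps, 3), [round(v, 3) for v in sample])
```

Every interpolated sample lies on the straight line between its real and fake counterpart, which is where the penalty is evaluated.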
The term $\nabla_{\hat x}D(\hat x)$ is the gradient of the critic's output with respect to the interpolated images. We can again use tf.GradientTape() to obtain the gradient:

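with tf.GradientTape() as gradient_tape:
    gradient_tape.watch(interpolates)
    critic_interpolates = self.critic(interpolates)
    # Pass the tensor itself rather than a list so the gradient keeps
    # the shape (batch, height, width, channels)
    grad = gradient_tape.gradient(critic_interpolates, interpolates)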

The next step is to compute the L2 norm:
$$\lVert \nabla_{\hat x}D(\hat x) \rVert_2$$
We square each value, sum the squares, and take the square root:

grad_loss = tf.square(grad)
grad_loss = tf.reduce_sum(grad_loss, axis=np.arange(1, len(grad_loss.shape)))
grad_loss = tf.sqrt(grad_loss)

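When calling tf.reduce_sum(), we exclude the first dimension from the axes because it is the batch dimension. The penalty pushes the gradient norm toward 1, which gives the final step of the gradient loss computation: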

grad_loss = tf.reduce_mean(tf.square(grad_loss - 1))

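The $\lambda$ in the equation is the ratio of the gradient penalty to the other critic losses and is set to 10 in this article. We now add the gradient penalty to the critic losses, backpropagate, and update the weights: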

total_loss = loss_real + loss_fake + LAMBDA * grad_loss
gradients = total_tape.gradient(total_loss, self.critic.variables)
self.optimizer_critic.apply_gradients(zip(gradients, self.critic.variables))
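As a sanity check of the penalty computation, it helps to use a case where the gradient is known in closed form. The sketch below assumes a toy linear critic D(x) = w . x, whose gradient with respect to any input is simply w; the function name `gradient_penalty` and the weight vectors are hypothetical and chosen for illustration.

```python
import math

LAMBDA = 10  # penalty weight used in the text

def gradient_penalty(per_sample_grads, lam=LAMBDA):
    """lam * mean over the batch of (||grad||_2 - 1)^2."""
    norms = [math.sqrt(sum(g * g for g in grad)) for grad in per_sample_grads]
    return lam * sum((n - 1) ** 2 for n in norms) / len(norms)

# For D(x) = w . x the gradient at every interpolated sample is w itself
w_unit = [0.6, 0.8]   # norm 1: satisfies the constraint
w_steep = [3.0, 4.0]  # norm 5: far too steep
batch = 4

print(gradient_penalty([w_unit] * batch))   # approximately 0
print(gradient_penalty([w_steep] * batch))  # 10 * (5 - 1)^2 = 160
```

Note that the norm is taken per sample before averaging; computing one norm over the whole batch would penalize a different quantity.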

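That is everything that needs to be added to WGAN to turn it into WGAN-GP. The following parts, however, need to be removed:

  1. Weight clipping
  2. Batch normalization in the critic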

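The gradient penalty penalizes the norm of the critic's gradient for each input independently. Batch normalization, however, makes each sample's output, and therefore its gradient, depend on the statistics of the whole batch. To avoid this problem, batch normalization is removed from the critic.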
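The coupling that batch normalization introduces can be seen with a few numbers. This is a simplified sketch, assuming a bare normalization step without the learnable scale and offset; the function `batch_norm` is written here for illustration and is not the Keras layer.

```python
import math

def batch_norm(values, eps=1e-5):
    """Normalize a batch of scalars with the batch's own mean and variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

batch_a = [1.0, 2.0, 3.0]
batch_b = [1.0, 2.0, 30.0]  # only the last sample differs

out_a = batch_norm(batch_a)
out_b = batch_norm(batch_b)

# The first sample is identical in both batches, yet its output changes
print(out_a[0], out_b[0])
```

Since the first sample's output depends on its batch neighbours, so does its gradient, which conflicts with a penalty defined on each input independently.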
The critic architecture is the same as in WGAN, except that batch normalization is left out:
[Figure: critic architecture]

The following samples were generated by the trained WGAN-GP:
[Figure: generated results]

They look sharp and clean, and closely resemble samples from the Fashion-MNIST dataset. Training is very stable and converges quickly.

Complete code

# wgan_and_wgan_gp.py
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.metrics import binary_accuracy

import tensorflow_datasets as tfds

import numpy as np
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')
print("Tensorflow", tf.__version__)

ds_train, ds_info = tfds.load('fashion_mnist', split='train',shuffle_files=True,with_info=True)
fig = tfds.show_examples(ds_train, ds_info)

batch_size = 64
image_shape = (32, 32, 1)

def preprocess(features):
    image = tf.image.resize(features['image'], image_shape[:2])    
    image = tf.cast(image, tf.float32)
    image = (image-127.5)/127.5
    return image

ds_train = ds_train.map(preprocess)
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(batch_size, drop_remainder=True).repeat()

train_num = ds_info.splits['train'].num_examples
train_steps_per_epoch = round(train_num/batch_size)
print(train_steps_per_epoch)

""" WGAN """
class WGAN():
    def __init__(self, input_shape):

        self.z_dim = 128
        self.input_shape = input_shape
        
        # losses
        self.loss_critic_real = {}
        self.loss_critic_fake = {}
        self.loss_critic = {}
        self.loss_generator = {}
        
        # critic
        self.n_critic = 5
        self.critic = self.build_critic()
        self.critic.trainable = False

        self.optimizer_critic = RMSprop(5e-5)

        # build generator pipeline with frozen critic
        self.generator = self.build_generator()
        critic_output = self.critic(self.generator.output)
        self.model = Model(self.generator.input, critic_output)
        self.model.compile(loss = self.wasserstein_loss,
                           optimizer =  RMSprop(5e-5))
        self.critic.trainable = True

        
    def wasserstein_loss(self, y_true, y_pred):

        w_loss = -tf.reduce_mean(y_true*y_pred)

        return w_loss

    def build_generator(self):

        DIM = 128
        model = tf.keras.Sequential(name='Generator') 

        model.add(layers.Input(shape=[self.z_dim])) 

        model.add(layers.Dense(4*4*4*DIM))
        model.add(layers.BatchNormalization()) 
        model.add(layers.ReLU())
        model.add(layers.Reshape((4,4,4*DIM))) 

        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))
        model.add(layers.Conv2D(2*DIM, 5, padding='same')) 
        model.add(layers.BatchNormalization()) 
        model.add(layers.ReLU())

        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))
        model.add(layers.Conv2D(DIM, 5, padding='same')) 
        model.add(layers.BatchNormalization()) 
        model.add(layers.ReLU())

        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))       
        model.add(layers.Conv2D(image_shape[-1], 5, padding='same', activation='tanh')) 

        return model             
    
    def build_critic(self):

        DIM = 128
        model = tf.keras.Sequential(name='critics') 

        model.add(layers.Input(shape=self.input_shape)) 

        model.add(layers.Conv2D(1*DIM, 5, strides=2, padding='same'))
        model.add(layers.LeakyReLU(0.2))

        model.add(layers.Conv2D(2*DIM, 5, strides=2, padding='same'))
        model.add(layers.BatchNormalization()) 
        model.add(layers.LeakyReLU(0.2))

        model.add(layers.Conv2D(4*DIM, 5, strides=2, padding='same'))
        model.add(layers.BatchNormalization()) 
        model.add(layers.LeakyReLU(0.2))


        model.add(layers.Flatten()) 
        model.add(layers.Dense(1)) 

        return model     
    
 
    def train_critic(self, real_images, batch_size):

        real_labels = tf.ones(batch_size)
        fake_labels = -tf.ones(batch_size)
                  
        g_input = tf.random.normal((batch_size, self.z_dim))
        fake_images = self.generator.predict(g_input)
        
        with tf.GradientTape() as total_tape:
            
            # forward pass
            pred_fake = self.critic(fake_images)
            pred_real = self.critic(real_images)
            
            # calculate losses
            loss_fake = self.wasserstein_loss(fake_labels, pred_fake)
            loss_real = self.wasserstein_loss(real_labels, pred_real)           

            # total loss
            total_loss = loss_fake + loss_real
            
            # apply gradients
            gradients = total_tape.gradient(total_loss, self.critic.trainable_variables)
            
            self.optimizer_critic.apply_gradients(zip(gradients, self.critic.trainable_variables))

        for layer in self.critic.layers: 
            weights = layer.get_weights() 
            weights = [tf.clip_by_value(w, -0.01, 0.01) for w in weights]
            layer.set_weights(weights) 

        return loss_fake, loss_real
                                                
    def train(self, data_generator, batch_size, steps, interval=200):

        val_g_input = tf.random.normal((batch_size, self.z_dim))
        real_labels = tf.ones(batch_size)

        for i in range(steps):
            for _ in range(self.n_critic):
                real_images = next(data_generator)
                loss_fake, loss_real = self.train_critic(real_images, batch_size)
                critic_loss = loss_fake + loss_real
                
            # train generator
            g_input = tf.random.normal((batch_size, self.z_dim))
            g_loss = self.model.train_on_batch(g_input, real_labels)
            
            self.loss_critic_real[i] = loss_real.numpy()
            self.loss_critic_fake[i] = loss_fake.numpy()
            self.loss_critic[i] = critic_loss.numpy()
            self.loss_generator[i] = g_loss

            if i%interval == 0:
                msg = "Step {}: g_loss {:.4f} critic_loss {:.4f} critic fake {:.4f} critic_real {:.4f}"\
                .format(i, g_loss, critic_loss, loss_fake, loss_real)
                print(msg)

                fake_images = self.generator.predict(val_g_input)
                self.plot_images(fake_images)
                self.plot_losses()

    def plot_images(self, images):   
        grid_row = 1
        grid_col = 8
        f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2.5, grid_row*2.5))
        for row in range(grid_row):
            for col in range(grid_col):
                if self.input_shape[-1]==1:
                    axarr[col].imshow(images[col,:,:,0]*0.5+0.5, cmap='gray')
                else:
                    axarr[col].imshow(images[col]*0.5+0.5)
                axarr[col].axis('off') 
        plt.show()

    def plot_losses(self):
        fig, (ax1, ax2) = plt.subplots(2, sharex=True)
        fig.set_figwidth(10)
        fig.set_figheight(6)
        ax1.plot(list(self.loss_critic.values()), label='Critic loss', alpha=0.7)
        ax1.set_title("Critic loss")
        ax2.plot(list(self.loss_generator.values()), label='Generator loss', alpha=0.7)
        ax2.set_title("Generator loss")

        plt.xlabel('Steps')
        plt.show()

wgan = WGAN(image_shape)
wgan.generator.summary()

wgan.critic.summary()

wgan.train(iter(ds_train), batch_size, 2000, 100)

z = tf.random.normal((8, 128))
generated_images = wgan.generator.predict(z)
wgan.plot_images(generated_images)

wgan.generator.save_weights('./wgan_models/wgan_fashion_minist.weights')

""" WGAN_GP """
class WGAN_GP():
    def __init__(self, input_shape):

        self.z_dim = 128
        self.input_shape = input_shape

        # critic
        self.n_critic = 5
        self.penalty_const = 10
        self.critic = self.build_critic()
        self.critic.trainable = False

        self.optimizer_critic = Adam(1e-4, 0.5, 0.9)

        # build generator pipeline with frozen critic
        self.generator = self.build_generator()
        critic_output = self.critic(self.generator.output)
        self.model = Model(self.generator.input, critic_output)
        self.model.compile(loss=self.wasserstein_loss, optimizer=Adam(1e-4, 0.5, 0.9))

    def wasserstein_loss(self, y_true, y_pred):

        w_loss = -tf.reduce_mean(y_true*y_pred)

        return w_loss
    
    def build_generator(self):

        DIM = 128
        model = Sequential([
            layers.Input(shape=[self.z_dim]),
            
            layers.Dense(4*4*4*DIM),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.Reshape((4,4,4*DIM)),

            layers.UpSampling2D((2,2), interpolation='bilinear'),
            layers.Conv2D(2*DIM, 5, padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),

            layers.UpSampling2D((2,2), interpolation='bilinear'),
            layers.Conv2D(2*DIM, 5, padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),

            layers.UpSampling2D((2,2), interpolation='bilinear'),
            layers.Conv2D(image_shape[-1], 5, padding='same', activation='tanh')
        ],name='Generator')

        return model
    
    def build_critic(self):

        DIM = 128
        model = Sequential([
            layers.Input(shape=self.input_shape),

            layers.Conv2D(1*DIM, 5, strides=2, padding='same', use_bias=False),
            layers.LeakyReLU(0.2),

            layers.Conv2D(2*DIM, 5, strides=2, padding='same', use_bias=False),
            layers.LeakyReLU(0.2),

            layers.Conv2D(4*DIM, 5, strides=2, padding='same', use_bias=False),
            layers.LeakyReLU(0.2),

            layers.Flatten(),
            layers.Dense(1)
        ], name='critics')

        return model
    
    def gradient_loss(self, grad):
        loss = tf.square(grad)
        loss = tf.reduce_sum(loss, axis=np.arange(1, len(loss.shape)))
        loss = tf.sqrt(loss)
        loss = tf.reduce_mean(tf.square(loss - 1))
        loss = self.penalty_const * loss
        return loss
    
    def train_critic(self, real_images, batch_size):
        real_labels = tf.ones(batch_size)
        fake_labels = -tf.ones(batch_size)

        g_input = tf.random.normal((batch_size, self.z_dim))
        fake_images = self.generator.predict(g_input)

        with tf.GradientTape() as gradient_tape, tf.GradientTape() as total_tape:
            # forward pass
            pred_fake = self.critic(fake_images)
            pred_real = self.critic(real_images)

            # calculate losses
            loss_fake = self.wasserstein_loss(fake_labels, pred_fake)
            loss_real = self.wasserstein_loss(real_labels, pred_real)

            # gradient penalty
            epsilon = tf.random.uniform((batch_size, 1, 1, 1))
            interpolates = epsilon * real_images + (1-epsilon) * fake_images
            gradient_tape.watch(interpolates)

            critic_interpolates = self.critic(interpolates)
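            # Differentiate w.r.t. the tensor itself (not a list) so the norm in
            # gradient_loss() is taken per sample rather than over the whole batch
            gradients_interpolates = gradient_tape.gradient(critic_interpolates, interpolates)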
            gradient_penalty = self.gradient_loss(gradients_interpolates)

            # total loss
            total_loss = loss_fake + loss_real + gradient_penalty

            # apply gradients
            gradients = total_tape.gradient(total_loss, self.critic.variables)

            self.optimizer_critic.apply_gradients(zip(gradients, self.critic.variables))
        return loss_fake, loss_real, gradient_penalty
    
    def train(self, data_generator, batch_size, steps, interval=100):
        val_g_input = tf.random.normal((batch_size, self.z_dim))
        real_labels = tf.ones(batch_size)

        for i in range(steps):
            for _ in range(self.n_critic):
                real_images = next(data_generator)
                loss_fake, loss_real, gradient_penalty = self.train_critic(real_images, batch_size)
                critic_loss = loss_fake + loss_real + gradient_penalty
            # train generator
            g_input = tf.random.normal((batch_size, self.z_dim))
            g_loss = self.model.train_on_batch(g_input, real_labels)
            if i%interval == 0:
                msg = "Step {}: g_loss {:.4f} critic_loss {:.4f} critic fake {:.4f} critic_real {:.4f} penalty {:.4f}".format(i, g_loss, critic_loss, loss_fake, loss_real, gradient_penalty)
                print(msg)

                fake_images = self.generator.predict(val_g_input)
                self.plot_images(fake_images)

    def plot_images(self, images):   
        grid_row = 1
        grid_col = 8
        f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2.5, grid_row*2.5))
        for row in range(grid_row):
            for col in range(grid_col):
                if self.input_shape[-1]==1:
                    axarr[col].imshow(images[col,:,:,0]*0.5+0.5, cmap='gray')
                else:
                    axarr[col].imshow(images[col]*0.5+0.5)
                axarr[col].axis('off') 
        plt.show()

wgan = WGAN_GP(image_shape)
wgan.train(iter(ds_train), batch_size, 5000, 100)

wgan.model.summary()

wgan.critic.summary()

z = tf.random.normal((8, 128))
generated_images = wgan.generator.predict(z)
wgan.plot_images(generated_images)

