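An Illustrated Guide to WGAN and Its Variant WGAN-GP, with TensorFlow 2 Implementations of Both
Building WGAN (Wasserstein GAN)
Since GANs were introduced, many papers have tried to tackle the instability of GAN training with heuristics, such as experimenting with different network architectures, hyperparameters, and optimizers. The introduction of Wasserstein GAN (WGAN) was a major breakthrough in this line of research.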
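WGAN alleviates, and in some cases eliminates, many of the problems encountered when training GANs. Its fundamental improvement over the original GAN is a change to the loss function. In theory, if the real and generated distributions do not overlap, the JS divergence is stuck at the constant log 2, so it gives the generator a zero gradient; as a function of the generator's parameters it is not even continuous, because it jumps when the two distributions coincide. WGAN solves this with a new loss function that is continuous everywhere and differentiable almost everywhere (under mild assumptions on the generator).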
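A small NumPy sketch of the standard two-point-mass example makes the difference concrete (the grid, the distributions and the helper names `js_divergence` and `wasserstein_1d` are illustrative assumptions, not code from this article). The "real" distribution puts all its mass at 0 and the "generated" one at θ. For every θ ≠ 0 the JS divergence stays at log 2 ≈ 0.6931, while the 1-D Wasserstein distance, computed here as the area between the two CDFs, equals |θ| and shrinks as the generator moves toward the data.

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    m = 0.5 * (p + q)
    def kl(a, b):
        mask = a > 0  # terms where a == 0 contribute nothing
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1d(p, q, dx):
    """W1 distance for distributions on the same uniform 1-D grid:
    the integral of |CDF_p - CDF_q|."""
    return np.sum(np.abs(np.cumsum(p) - np.cumsum(q))) * dx

grid = np.linspace(0.0, 1.0, 11)   # points 0.0, 0.1, ..., 1.0
dx = grid[1] - grid[0]

p = np.zeros_like(grid); p[0] = 1.0          # "real": all mass at 0.0
for k in (1, 5, 10):
    q = np.zeros_like(grid); q[k] = 1.0      # "generated": all mass at theta = grid[k]
    print(f"theta={grid[k]:.1f}  JS={js_divergence(p, q):.4f}  W1={wasserstein_1d(p, q, dx):.4f}")
# theta=0.1  JS=0.6931  W1=0.1000
# theta=0.5  JS=0.6931  W1=0.5000
# theta=1.0  JS=0.6931  W1=1.0000
```

Only the Wasserstein column changes with θ, so only it tells the generator which way to move.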
Introducing the Wasserstein loss
We are all familiar with the objective function of the original GAN, so here is only a brief recap:
$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))]$$
where $D$ denotes the discriminator, $G$ the generator, $x$ the real data, and $z$ the latent variable.
Switching the generator to the commonly used non-saturating form, in which it maximizes $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$, gives the following value function:
$$\mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log D(G(z))]$$
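WGAN uses a new loss function based on the Earth Mover's distance, also called the Wasserstein distance. It measures the minimum amount of work needed to transform one distribution into the other. Mathematically, it is the minimum, over all joint distributions of real and generated images, of the expected distance between paired samples. The WGAN value function becomes: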
$$\mathbb{E}_{x\sim p_{data}(x)}[D(x)] - \mathbb{E}_{z\sim p_z(z)}[D(G(z))]$$
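The step from "minimum over joint distributions" to this value function is the Kantorovich-Rubinstein duality, on which the WGAN paper relies. Restated here for reference (a standard result, not derived in this article):

$$W(p_{data}, p_g) = \inf_{\gamma \in \Pi(p_{data},\, p_g)} \mathbb{E}_{(x,y)\sim\gamma}\big[\lVert x - y \rVert\big] = \sup_{\lVert D \rVert_L \le 1} \mathbb{E}_{x\sim p_{data}}[D(x)] - \mathbb{E}_{x\sim p_g}[D(x)]$$

Here $\Pi(p_{data}, p_g)$ is the set of joint distributions whose marginals are $p_{data}$ and $p_g$, $p_g$ is the distribution of $G(z)$ with $z \sim p_z$ (so the last expectation equals $\mathbb{E}_{z\sim p_z(z)}[D(G(z))]$), and the supremum runs over 1-Lipschitz functions $D$. This is why the critic has to be kept 1-Lipschitz.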
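We derive the loss function from this value function. The critic wants to maximize it, so we negate it to get a quantity to minimize; attaching a label $y_i$ to each of the $N$ samples in a batch, it can be written as:
$$-\frac{1}{N}\sum_{i=1}^{N} y_i D(x_i)$$
This is the mean of the critic's outputs multiplied by -1, where the label $y_i$ is +1 for real images and -1 for fake images. We can therefore implement the Wasserstein loss as a custom TensorFlow Keras loss function: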
```python
def wasserstein_loss(self, y_true, y_pred):
    # y_true is +1 for real images and -1 for fake images
    w_loss = -tf.reduce_mean(y_true * y_pred)
    return w_loss
```
It is designed to maximize the scores of real images relative to fake images. For this reason, the discriminator in WGAN is also called the critic.
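A quick NumPy check of the sign convention, using made-up critic scores (the numbers are purely illustrative). Computing the loss separately for a real batch (labels +1) and a fake batch (labels -1), as the complete code does with `loss_real` and `loss_fake`, gives a total critic loss equal to $-(\mathbb{E}[D(x)] - \mathbb{E}[D(G(z))])$, so minimizing it maximizes the WGAN value function.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    # Same formula as the Keras loss: -mean(y * D(x)), labels +1 (real) / -1 (fake)
    return -np.mean(y_true * y_pred)

# Hypothetical critic scores for 4 real and 4 fake images
d_real = np.array([2.0, 1.5, 3.0, 0.5])
d_fake = np.array([-1.0, 0.0, 0.5, -2.5])

loss_real = wasserstein_loss(np.ones(4), d_real)    # = -mean(D(x))
loss_fake = wasserstein_loss(-np.ones(4), d_fake)   # = +mean(D(G(z)))
critic_loss = loss_real + loss_fake

value = d_real.mean() - d_fake.mean()   # estimate of E[D(x)] - E[D(G(z))]
print(critic_loss, value)   # -2.5 2.5
```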
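However, because WGAN removes the sigmoid activation from the discriminator's output, the critic's predictions are unbounded and must be constrained by the 1-Lipschitz condition.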
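Implementing the 1-Lipschitz constraint
The mathematical assumption behind the Wasserstein loss is that the critic is a 1-Lipschitz function. We say that the critic $D(x)$ is 1-Lipschitz if it satisfies the following inequality: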
$$\lvert D(x_1) - D(x_2) \rvert \le \lVert x_1 - x_2 \rVert$$
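For two images $x_1$ and $x_2$, the absolute difference between the critic's outputs must be less than or equal to the distance between the images themselves (a norm of their pixel-wise difference). In other words, the critic's output should not differ too much between different images, whether they are real or fake. When WGAN was proposed, the authors could not find a proper way to enforce this inequality directly. Instead, they came up with a workaround: clip the critic's weights to small values. This keeps each layer's output, and ultimately the critic's output, within a small range. In the WGAN paper, the weights are clipped to the range [-0.01, 0.01].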
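The following NumPy sketch uses a hypothetical linear critic $D(x) = w \cdot x$ (an assumption for illustration, not the convolutional critic used later) to show what clipping does and does not guarantee. For a linear critic the Lipschitz constant is $\lVert w \rVert_2$; clipping every weight to $[-c, c]$ bounds it by $c\sqrt{n}$, so the critic becomes K-Lipschitz for some K fixed by the architecture and $c$, rather than exactly 1-Lipschitz.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 32 * 32                      # a flattened 32x32 grayscale image
c = 0.01                         # clipping range used in the WGAN paper

w = rng.normal(size=n)           # hypothetical linear critic D(x) = w . x
w_clipped = np.clip(w, -c, c)    # weight clipping

lipschitz = np.linalg.norm(w_clipped)   # Lipschitz constant of x -> w . x (L2 norm)
bound = c * np.sqrt(n)                  # worst case: every weight sits at +-c

# Check |D(x1) - D(x2)| <= K * ||x1 - x2|| on random image pairs (Cauchy-Schwarz)
x1 = rng.uniform(-1, 1, size=(100, n))
x2 = rng.uniform(-1, 1, size=(100, n))
lhs = np.abs(x1 @ w_clipped - x2 @ w_clipped)
rhs = lipschitz * np.linalg.norm(x1 - x2, axis=1)
print(lipschitz <= bound, np.all(lhs <= rhs + 1e-9))   # True True
```

Here K is about 0.32, not 1; the Wasserstein estimate is then scaled by K, which the WGAN paper accepts because only the gradient direction matters for training.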
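Weight clipping can be implemented in two ways. One is to write a custom constraint and use it when instantiating new layers:
```python
import tensorflow as tf

class WeightsClip(tf.keras.constraints.Constraint):
    # Clips every weight into [min_value, max_value] after each optimizer update
    def __init__(self, min_value=-0.01, max_value=0.01):
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, w):
        return tf.clip_by_value(w, self.min_value, self.max_value)
```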
The constraint can then be passed to any layer that accepts constraint arguments:
```python
from tensorflow.keras.layers import Conv2D, BatchNormalization

model = tf.keras.Sequential(name='critics')
model.add(Conv2D(16, 3, strides=2, padding='same',
                 kernel_constraint=WeightsClip(),
                 bias_constraint=WeightsClip()))
model.add(BatchNormalization(
    beta_constraint=WeightsClip(),
    gamma_constraint=WeightsClip()))
```
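However, adding constraint arguments every time a layer is created bloats the code. Since we don't need to pick which layers to clip, we can instead loop over the layers, read their weights, clip them, and write them back: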
```python
for layer in critic.layers:
    weights = layer.get_weights()
    weights = [tf.clip_by_value(w, -0.01, 0.01) for w in weights]
    layer.set_weights(weights)
```
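Training process
In the original GAN theory, the discriminator should be trained (to optimality) before the generator. In practice, however, the discriminator learns faster, and once it becomes too strong the gradients it passes to the generator vanish. With the Wasserstein loss, gradients can be obtained everywhere, so we no longer have to worry about the critic becoming too strong relative to the generator.
Therefore, in WGAN the critic is trained five times for every generator training step. To do this, we write the critic's training step as a separate function that can be called in a loop: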
```python
for _ in range(self.n_critic):   # n_critic = 5 critic updates per generator update
    real_images = next(data_generator)
    critic_loss = self.train_critic(real_images, batch_size)
```
For the generator's training step, we build a combined model with a frozen critic:
```python
self.critic = self.build_critic()
self.critic.trainable = False            # freeze the critic inside the combined model
self.generator = self.build_generator()
critic_output = self.critic(self.generator.output)
self.model = Model(self.generator.input, critic_output)   # z -> G(z) -> D(G(z))
self.model.compile(loss = self.wasserstein_loss, optimizer = RMSprop(3e-4))
self.critic.trainable = True             # unfreeze for the critic's own training step
```
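In the code above, setting trainable = False freezes the critic's layers; the critic is then chained to the generator to create a new model, which is compiled. Afterwards we can set the critic back to trainable; this does not affect the model we have already compiled, because Keras keeps the trainable state that was in effect when the model was compiled.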
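We use the train_on_batch() API to run a single training step, which automatically performs the forward pass, loss computation, backpropagation, and weight update:
```python
# real_labels (+1) ask the frozen critic to score the generated images as real
g_loss = self.model.train_on_batch(g_input, real_labels)
```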
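To see the alternating schedule end to end, here is a deliberately tiny NumPy sketch. The toy setup is an assumption for illustration, not the article's model: 1-D data, a linear critic $D(x) = w x$ with weight clipping, and a generator $G(z) = z + b$ that can only shift its mean. Each generator step is preceded by `n_critic = 5` critic steps, and both gradients are written out by hand.

```python
import numpy as np

c = 0.01                       # weight-clipping range, as in the WGAN paper
n_critic = 5                   # critic updates per generator update
lr_critic, lr_gen = 1.0, 5.0   # toy learning rates chosen for this example only

z = np.linspace(-1.0, 1.0, 64)     # fixed latent batch with mean 0
real = 3.0 + z                     # "real" data centred at 3
w, b = 0.0, 0.0                    # critic weight, generator bias

for step in range(200):
    fake = z + b                   # generator output G(z) = z + b
    for _ in range(n_critic):
        # critic loss = -(mean D(real) - mean D(fake)) = -w * (mean(real) - mean(fake))
        grad_w = -(real.mean() - fake.mean())
        w = np.clip(w - lr_critic * grad_w, -c, c)   # SGD step, then weight clipping
    # generator loss = -mean D(G(z)) = -w * mean(z + b), so d(loss)/db = -w
    b = b - lr_gen * (-w)

print(f"b = {b:.3f}, w = {w:+.4f}")
```

With these settings `b` climbs from 0 toward 3 in increments of `lr_gen * c = 0.05` and then oscillates within about ±0.05 of it, while `|w|` stays at or near the clipping bound `c` for the whole run.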
The figure below shows the WGAN generator architecture:
The figure below shows the WGAN critic architecture:
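Although it improves on the original GAN, WGAN is still difficult to train, and the images it produces are no better in quality than those of the original GAN. Next, we implement WGAN-GP, a variant of WGAN that trains faster and produces sharper images.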
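Implementing the gradient penalty (WGAN-GP)
As the WGAN authors themselves acknowledged, weight clipping is not an ideal way to enforce the Lipschitz constraint. It has two drawbacks: underused network capacity and exploding/vanishing gradients. By clipping the weights we also limit the critic's ability to learn: weight clipping forces the network to learn only simple functions, so the capacity of the neural network goes underused. Second, the clipping value has to be tuned carefully. If it is set too high, gradients can explode, violating the Lipschitz constraint. If it is set too low, gradients vanish as they are backpropagated through the network. In addition, weight clipping pushes the weights toward the two extremes of the clipping range, as shown in the figure below:
Therefore, the gradient penalty (GP) was proposed to replace weight clipping for enforcing the Lipschitz constraint: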
$$\text{Gradient penalty} = \lambda\, \mathbb{E}_{\hat{x}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]$$
Let's go through each variable in the equation and implement it in code.
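We usually use $x$ for real images, but the equation now contains $\hat{x}$, which is a point-wise interpolation between a real image and a fake image. The mixing ratio (epsilon) is drawn from a uniform distribution on [0, 1]: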
```python
# One mixing ratio per image, broadcast over height, width and channels
epsilon = tf.random.uniform((batch_size, 1, 1, 1))
interpolates = epsilon * real_images + (1 - epsilon) * fake_images
```
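This follows the WGAN-GP paper. For our purposes it can be understood this way: because the penalty is evaluated on mixtures of real and fake images, we do not need to compute it separately for real images and for fake images.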
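A NumPy sketch of the interpolation with random stand-in images (the batch size and the 32×32×1 shape are assumptions matching the images used in the complete code). Because epsilon has shape `(batch_size, 1, 1, 1)`, broadcasting draws one mixing ratio per image and applies it to every pixel of that image, so each $\hat{x}$ lies on the straight line between its own real/fake pair.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 4
real_images = rng.uniform(-1, 1, size=(batch_size, 32, 32, 1))  # stand-ins for real images
fake_images = rng.uniform(-1, 1, size=(batch_size, 32, 32, 1))  # stand-ins for G(z)

# One ratio per image, broadcast over height, width and channels
epsilon = rng.uniform(0.0, 1.0, size=(batch_size, 1, 1, 1))
interpolates = epsilon * real_images + (1 - epsilon) * fake_images

# Equivalently fake + eps * (real - fake): a point on the segment between the pair
print(interpolates.shape)   # (4, 32, 32, 1)
```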
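The term $\nabla_{\hat{x}} D(\hat{x})$ is the gradient of the critic's output with respect to the interpolated images. We can again use tf.GradientTape() to obtain it:
```python
with tf.GradientTape() as gradient_tape:
    # interpolates is a plain tensor, not a variable, so the tape must watch it explicitly
    gradient_tape.watch(interpolates)
    critic_interpolates = self.critic(interpolates)
# Pass the tensor itself (not [interpolates]) so the result keeps the (batch, H, W, C) shape
grad = gradient_tape.gradient(critic_interpolates, interpolates)
```
The next step is to compute the L2 norm:
$$\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2$$
We square each value, sum them, and take the square root:
```python
grad_loss = tf.square(grad)
grad_loss = tf.reduce_sum(grad_loss, axis=np.arange(1, len(grad_loss.shape)))
grad_loss = tf.sqrt(grad_loss)
```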
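When calling tf.reduce_sum(), we exclude the first axis because it is the batch dimension, so we get one norm per sample. The penalty pushes each gradient norm toward 1, which is the last step in computing the gradient loss:
```python
grad_loss = tf.reduce_mean(tf.square(grad_loss - 1))
```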
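The $\lambda$ in the equation is the weight of the gradient penalty relative to the other critic losses; here it is set to 10. We now add the critic losses and the gradient penalty together, backpropagate, and update the weights: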
```python
LAMBDA = 10   # weight of the gradient penalty
total_loss = loss_real + loss_fake + LAMBDA * grad_loss
gradients = total_tape.gradient(total_loss, self.critic.variables)
self.optimizer_critic.apply_gradients(zip(gradients, self.critic.variables))
```
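Putting the penalty steps together in NumPy, with a hypothetical linear critic $D(x) = \sum w \cdot x$ whose input gradient is known in closed form (it is simply $w$ for every sample; a real critic needs tf.GradientTape as shown above). The per-sample L2 norm is taken over every axis except the batch axis, and the penalty is $\lambda$ times the mean squared deviation of that norm from 1. The helper name `gradient_penalty` is illustrative.

```python
import numpy as np

LAMBDA = 10.0   # gradient-penalty weight, as in the article

def gradient_penalty(grad):
    # grad: per-sample gradients of the critic w.r.t. its inputs, shape (B, H, W, C)
    norms = np.sqrt(np.sum(np.square(grad), axis=tuple(range(1, grad.ndim))))  # one L2 norm per sample
    return LAMBDA * np.mean(np.square(norms - 1.0))

rng = np.random.default_rng(0)
batch_size = 8
w = rng.normal(size=(32, 32, 1))
w *= 2.0 / np.linalg.norm(w)           # rescale so that ||w||_2 = 2

# For D(x) = sum(w * x) the gradient w.r.t. x is w for every sample
grad = np.broadcast_to(w, (batch_size, 32, 32, 1))
gp = gradient_penalty(grad)
print(round(gp, 6))                            # 10 * (2 - 1)^2 = 10.0

# A critic whose gradients already have unit norm is not penalized
print(round(gradient_penalty(grad / 2.0), 6))  # 0.0
```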
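That is everything that needs to be added to WGAN to turn it into WGAN-GP. However, the following must be removed:
- Weight clipping
- Batch normalization in the critic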
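The gradient penalty penalizes the norm of the critic's gradient independently for each input. Batch normalization, however, makes each sample's output, and therefore its gradient, depend on the statistics of the whole batch. To avoid this problem, batch normalization is removed from the critic.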
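A small NumPy illustration of the batch dependence, using a bare training-mode normalization with no learned scale or shift (a simplifying assumption; `eps=1e-3` mirrors the Keras default). Changing only the last sample in the batch changes the normalized output of the first sample, so the critic's output for one $\hat{x}_i$ is no longer a function of $\hat{x}_i$ alone.

```python
import numpy as np

def batch_norm(x, eps=1e-3):
    # Training-mode batch normalization over the batch axis (no gamma/beta)
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

batch = np.array([[0.0], [1.0], [2.0], [3.0]])
out_before = batch_norm(batch)[0]

batch_changed = batch.copy()
batch_changed[3] = 10.0          # change only the last sample
out_after = batch_norm(batch_changed)[0]

# The first sample's output changed even though its own input did not
print(out_before, out_after)
```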
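The critic architecture is the same as in WGAN, except without batch normalization:
Below are samples generated by the trained WGAN-GP:
They look sharp and clean, and are very similar to the samples in the Fashion-MNIST dataset. Training was very stable and converged quickly.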
Complete code
```python
# wgan_and_wgan_gp.py
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.metrics import binary_accuracy
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
print("Tensorflow", tf.__version__)

ds_train, ds_info = tfds.load('fashion_mnist', split='train', shuffle_files=True, with_info=True)
fig = tfds.show_examples(ds_train, ds_info)

batch_size = 64
image_shape = (32, 32, 1)

def preprocess(features):
    image = tf.image.resize(features['image'], image_shape[:2])
    image = tf.cast(image, tf.float32)
    image = (image - 127.5) / 127.5
    return image

ds_train = ds_train.map(preprocess)
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(batch_size, drop_remainder=True).repeat()
train_num = ds_info.splits['train'].num_examples
train_steps_per_epoch = round(train_num / batch_size)
print(train_steps_per_epoch)

""" WGAN """
class WGAN():
    def __init__(self, input_shape):
        self.z_dim = 128
        self.input_shape = input_shape
        # losses
        self.loss_critic_real = {}
        self.loss_critic_fake = {}
        self.loss_critic = {}
        self.loss_generator = {}
        # critic
        self.n_critic = 5
        self.critic = self.build_critic()
        self.critic.trainable = False
        self.optimizer_critic = RMSprop(5e-5)
        # build generator pipeline with frozen critic
        self.generator = self.build_generator()
        critic_output = self.critic(self.generator.output)
        self.model = Model(self.generator.input, critic_output)
        self.model.compile(loss = self.wasserstein_loss,
                           optimizer = RMSprop(5e-5))
        self.critic.trainable = True

    def wasserstein_loss(self, y_true, y_pred):
        w_loss = -tf.reduce_mean(y_true*y_pred)
        return w_loss

    def build_generator(self):
        DIM = 128
        model = tf.keras.Sequential(name='Generator')
        model.add(layers.Input(shape=[self.z_dim]))
        model.add(layers.Dense(4*4*4*DIM))
        model.add(layers.BatchNormalization())
        model.add(layers.ReLU())
        model.add(layers.Reshape((4,4,4*DIM)))
        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))
        model.add(layers.Conv2D(2*DIM, 5, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.ReLU())
        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))
        model.add(layers.Conv2D(DIM, 5, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.ReLU())
        model.add(layers.UpSampling2D((2,2), interpolation="bilinear"))
        model.add(layers.Conv2D(image_shape[-1], 5, padding='same', activation='tanh'))
        return model

    def build_critic(self):
        DIM = 128
        model = tf.keras.Sequential(name='critics')
        model.add(layers.Input(shape=self.input_shape))
        model.add(layers.Conv2D(1*DIM, 5, strides=2, padding='same'))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Conv2D(2*DIM, 5, strides=2, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Conv2D(4*DIM, 5, strides=2, padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Flatten())
        model.add(layers.Dense(1))
        return model

    def train_critic(self, real_images, batch_size):
        real_labels = tf.ones(batch_size)
        fake_labels = -tf.ones(batch_size)
        g_input = tf.random.normal((batch_size, self.z_dim))
        fake_images = self.generator.predict(g_input)
        with tf.GradientTape() as total_tape:
            # forward pass
            pred_fake = self.critic(fake_images)
            pred_real = self.critic(real_images)
            # calculate losses
            loss_fake = self.wasserstein_loss(fake_labels, pred_fake)
            loss_real = self.wasserstein_loss(real_labels, pred_real)
            # total loss
            total_loss = loss_fake + loss_real
        # apply gradients
        gradients = total_tape.gradient(total_loss, self.critic.trainable_variables)
        self.optimizer_critic.apply_gradients(zip(gradients, self.critic.trainable_variables))
        for layer in self.critic.layers:
            weights = layer.get_weights()
            weights = [tf.clip_by_value(w, -0.01, 0.01) for w in weights]
            layer.set_weights(weights)
        return loss_fake, loss_real

    def train(self, data_generator, batch_size, steps, interval=200):
        val_g_input = tf.random.normal((batch_size, self.z_dim))
        real_labels = tf.ones(batch_size)
        for i in range(steps):
            for _ in range(self.n_critic):
                real_images = next(data_generator)
                loss_fake, loss_real = self.train_critic(real_images, batch_size)
                critic_loss = loss_fake + loss_real
            # train generator
            g_input = tf.random.normal((batch_size, self.z_dim))
            g_loss = self.model.train_on_batch(g_input, real_labels)
            self.loss_critic_real[i] = loss_real.numpy()
            self.loss_critic_fake[i] = loss_fake.numpy()
            self.loss_critic[i] = critic_loss.numpy()
            self.loss_generator[i] = g_loss
            if i % interval == 0:
                msg = "Step {}: g_loss {:.4f} critic_loss {:.4f} critic fake {:.4f} critic_real {:.4f}"\
                    .format(i, g_loss, critic_loss, loss_fake, loss_real)
                print(msg)
                fake_images = self.generator.predict(val_g_input)
                self.plot_images(fake_images)
                self.plot_losses()

    def plot_images(self, images):
        grid_row = 1
        grid_col = 8
        f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2.5, grid_row*2.5))
        for row in range(grid_row):
            for col in range(grid_col):
                if self.input_shape[-1] == 1:
                    axarr[col].imshow(images[col,:,:,0]*0.5+0.5, cmap='gray')
                else:
                    axarr[col].imshow(images[col]*0.5+0.5)
                axarr[col].axis('off')
        plt.show()

    def plot_losses(self):
        fig, (ax1, ax2) = plt.subplots(2, sharex=True)
        fig.set_figwidth(10)
        fig.set_figheight(6)
        ax1.plot(list(self.loss_critic.values()), label='Critic loss', alpha=0.7)
        ax1.set_title("Critic loss")
        ax2.plot(list(self.loss_generator.values()), label='Generator loss', alpha=0.7)
        ax2.set_title("Generator loss")
        plt.xlabel('Steps')
        plt.show()
```
```python
wgan = WGAN(image_shape)
wgan.generator.summary()
wgan.critic.summary()
wgan.train(iter(ds_train), batch_size, 2000, 100)
z = tf.random.normal((8, 128))
generated_images = wgan.generator.predict(z)
wgan.plot_images(generated_images)
wgan.generator.save_weights('./wgan_models/wgan_fashion_minist.weights')

""" WGAN_GP """
class WGAN_GP():
    def __init__(self, input_shape):
        self.z_dim = 128
        self.input_shape = input_shape
        # critic
        self.n_critic = 5
        self.penalty_const = 10
        self.critic = self.build_critic()
        self.critic.trainable = False
        self.optimizer_critic = Adam(1e-4, 0.5, 0.9)
        # build generator pipeline with frozen critic
        self.generator = self.build_generator()
        critic_output = self.critic(self.generator.output)
        self.model = Model(self.generator.input, critic_output)
        self.model.compile(loss=self.wasserstein_loss, optimizer=Adam(1e-4, 0.5, 0.9))
        # unfreeze the critic again (as in WGAN); the compiled generator model keeps it frozen
        self.critic.trainable = True

    def wasserstein_loss(self, y_true, y_pred):
        w_loss = -tf.reduce_mean(y_true*y_pred)
        return w_loss

    def build_generator(self):
        DIM = 128
        model = Sequential([
            layers.Input(shape=[self.z_dim]),
            layers.Dense(4*4*4*DIM),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.Reshape((4,4,4*DIM)),
            layers.UpSampling2D((2,2), interpolation='bilinear'),
            layers.Conv2D(2*DIM, 5, padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.UpSampling2D((2,2), interpolation='bilinear'),
            layers.Conv2D(2*DIM, 5, padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.UpSampling2D((2,2), interpolation='bilinear'),
            layers.Conv2D(image_shape[-1], 5, padding='same', activation='tanh')
        ], name='Generator')
        return model

    def build_critic(self):
        DIM = 128
        model = Sequential([
            layers.Input(shape=self.input_shape),
            layers.Conv2D(1*DIM, 5, strides=2, padding='same', use_bias=False),
            layers.LeakyReLU(0.2),
            layers.Conv2D(2*DIM, 5, strides=2, padding='same', use_bias=False),
            layers.LeakyReLU(0.2),
            layers.Conv2D(4*DIM, 5, strides=2, padding='same', use_bias=False),
            layers.LeakyReLU(0.2),
            layers.Flatten(),
            layers.Dense(1)
        ], name='critics')
        return model

    def gradient_loss(self, grad):
        loss = tf.square(grad)
        loss = tf.reduce_sum(loss, axis=np.arange(1, len(loss.shape)))
        loss = tf.sqrt(loss)
        loss = tf.reduce_mean(tf.square(loss - 1))
        loss = self.penalty_const * loss
        return loss

    def train_critic(self, real_images, batch_size):
        real_labels = tf.ones(batch_size)
        fake_labels = -tf.ones(batch_size)
        g_input = tf.random.normal((batch_size, self.z_dim))
        fake_images = self.generator.predict(g_input)
        with tf.GradientTape() as gradient_tape, tf.GradientTape() as total_tape:
            # forward pass
            pred_fake = self.critic(fake_images)
            pred_real = self.critic(real_images)
            # calculate losses
            loss_fake = self.wasserstein_loss(fake_labels, pred_fake)
            loss_real = self.wasserstein_loss(real_labels, pred_real)
            # gradient penalty
            epsilon = tf.random.uniform((batch_size, 1, 1, 1))
            interpolates = epsilon * real_images + (1-epsilon) * fake_images
            gradient_tape.watch(interpolates)
            critic_interpolates = self.critic(interpolates)
            # pass the tensor itself so the gradient keeps its (batch, H, W, C) shape
            gradients_interpolates = gradient_tape.gradient(critic_interpolates, interpolates)
            gradient_penalty = self.gradient_loss(gradients_interpolates)
            # total loss
            total_loss = loss_fake + loss_real + gradient_penalty
        # apply gradients
        gradients = total_tape.gradient(total_loss, self.critic.variables)
        self.optimizer_critic.apply_gradients(zip(gradients, self.critic.variables))
        return loss_fake, loss_real, gradient_penalty

    def train(self, data_generator, batch_size, steps, interval=100):
        val_g_input = tf.random.normal((batch_size, self.z_dim))
        real_labels = tf.ones(batch_size)
        for i in range(steps):
            for _ in range(self.n_critic):
                real_images = next(data_generator)
                loss_fake, loss_real, gradient_penalty = self.train_critic(real_images, batch_size)
                critic_loss = loss_fake + loss_real + gradient_penalty
            # train generator
            g_input = tf.random.normal((batch_size, self.z_dim))
            g_loss = self.model.train_on_batch(g_input, real_labels)
            if i % interval == 0:
                msg = "Step {}: g_loss {:.4f} critic_loss {:.4f} critic fake {:.4f} critic_real {:.4f} penalty {:.4f}"\
                    .format(i, g_loss, critic_loss, loss_fake, loss_real, gradient_penalty)
                print(msg)
                fake_images = self.generator.predict(val_g_input)
                self.plot_images(fake_images)

    def plot_images(self, images):
        grid_row = 1
        grid_col = 8
        f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2.5, grid_row*2.5))
        for row in range(grid_row):
            for col in range(grid_col):
                if self.input_shape[-1] == 1:
                    axarr[col].imshow(images[col,:,:,0]*0.5+0.5, cmap='gray')
                else:
                    axarr[col].imshow(images[col]*0.5+0.5)
                axarr[col].axis('off')
        plt.show()

wgan = WGAN_GP(image_shape)
wgan.train(iter(ds_train), batch_size, 5000, 100)
wgan.model.summary()
wgan.critic.summary()
z = tf.random.normal((8, 128))
generated_images = wgan.generator.predict(z)
wgan.plot_images(generated_images)
```