生成模型——自回歸模型詳解與PixelCNN構建
自回歸模型(Autoregressive models)
深度神經網絡生成算法主要分為三類:
- 生成對抗網絡(Generative Adversarial Network, GAN)
- 可變自動編碼器(Variational Autoencoder, VAE)
- 自回歸模型(Autoregressive models)
VAE已經在《變分自編碼器(VAE)原理與實現(tensorflow2.x)》中進行了介紹。GAN的詳細信息參考《深度卷積生成對抗網絡(DCGAN)原理與實現(采用Tensorflow2.x)》。將在這里介紹鮮為人知的自回歸模型,盡管自回歸在圖像生成中並不常見,但自回歸仍然是研究的活躍領域,DeepMind的WaveNet使用自回歸來生成逼真的音頻。在本文中,將介紹自回歸模型並構建PixelCNN模型。
簡介
Autoregressive中的“Auto”意味着自我(self
),而機器學習術語的回歸(regress
)意味着預測新的值。將它們放在一起,自回歸意味着我們使用模型基於模型的過去數據點來預測新數據點。
設圖像的概率分布是 p ( x ) p(x) p(x)是像素的聯合概率分布 p ( x 1 , x 2 , … x n ) p(x_1, x_2, …x_n) p(x1,x2,…xn),由於高維數很難建模。在這里,我們假設一個像素的值僅取決於它之前的像素的值。換句話說,當前像素僅以其前一像素為條件,即 p ( x i ) = p ( x i ∣ x i − 1 ) p ( x i − 1 ) p(x_i) = p(x_i | x_{i-1}) p(x_{i-1}) p(xi)=p(xi∣xi−1)p(xi−1),我們就可以將聯合概率近似為條件概率的乘積:
p ( x ) = p ( x n , x n − 1 , … , x 2 , x 1 ) p(x) = p(x_n, x_{n-1}, …, x_2, x_1) p(x)=p(xn,xn−1,…,x2,x1)
p ( x ) = p ( x n ∣ x n − 1 ) . . . p ( x 3 ∣ x 2 ) p ( x 2 ∣ x 1 ) p ( x 1 ) p(x) = p(x_n | x_{n-1})... p(x_3 | x_2) p(x_2 | x_1) p(x_1) p(x)=p(xn∣xn−1)...p(x3∣x2)p(x2∣x1)p(x1)
舉一個具體的例子,假設在圖像的中心附近包含一個紅色的蘋果,並且該蘋果被綠葉包圍,在此情況下,假設僅存在兩種可能的顏色:紅色和綠色。 x 1 x_1 x1是左上像素,所以 p ( x 1 ) p(x_1) p(x1)表示左上像素是綠色還是紅色的概率。如果 x 1 x_1 x1為綠色,則其右邊 p ( x 2 ) p(x_2) p(x2)的像素也可能也為綠色,因為它可能會有更多的葉子。但是,盡管可能性較小,但它也可能是紅色的。
繼續進行計算,我們最終將得到紅色像素。從那個像素開始,接下來的幾個像素也很可能也是紅色的,這比必須同時考慮所有像素要簡單得多。
PixelRNN
PixelRNN由DeepMind於2016年提出。正如名稱RNN(Recurrent Neural Network, 遞歸神經網絡)所暗示的那樣,該模型使用一種稱為長短期記憶(LSTM)的RNN來學習圖像的分布。它在LSTM中的一個步驟中一次讀取圖像的一行,並使用一維卷積層對其進行處理,然后將激活信息饋送到后續層中以預測該行的像素。
由於LSTM運行緩慢,因此需要花費很長時間來訓練和生成樣本。因此,我們不會對其進行過多的研究,而將注意力轉移到同一論文中提出的一種變體——PixelCNN。
使用TensorFlow 2構建PixelCNN模型
PixelCNN僅由卷積層組成,使其比PixelRNN快得多。在這里,我們將為使用MNIST數據集訓練一個簡單的PixelCNN模型。
輸入和標簽
MNIST由28 x 28 x 1灰度數字手寫數字組成。它只有一個通道:
在本實驗中,通過將圖像轉換為二進制數據來簡化問題:0代表黑色,1代表白色:
def binarize(image, label):
image = tf.cast(image, tf.float32)
image = tf.math.round(image/255.)
return image, tf.cast(image, tf.int32)
該函數需要兩個輸入——圖像和標簽。該函數的前兩行將圖像轉換為二進制float32格式,即0.0或1.0。並且,我們將二進制圖像轉換為整數並返回,以遵循使用整數作為標簽的慣例而已。返回的數據,將作為網絡訓練的輸入和標簽,都是28 x 28 x 1的二進制MNIST圖像,它們僅在數據類型上有所不同。
掩膜
與PixelRNN逐行讀取不同,PixelCNN在圖像中從左到右,從上到下滑動卷積核。當執行卷積以預測當前像素時,傳統的卷積核能夠看到當前輸入像素及其周圍的像素,其中包括當前像素之后的像素信息,這與在簡介部分的條件概率假設相悖。
為了避免這種情況,我們需要確保CNN在預測輸出像素 x i x_i xi時不會看到輸入像素 x i x_i xi。
這是通過使用掩膜卷積來實現的,其中在執行卷積之前將掩膜應用於卷積核權重。下圖顯示了一個7 x 7卷積核的掩膜,其中從中心開始的權重為0。這會阻止CNN看到它正在預測的像素(卷積核的中心)以及所有之后的像素。這稱為A型掩膜
,僅應用於輸入層。
由於中心像素在第一層中被遮擋,因此我們不再需要在后面的層中隱藏中心要素。實際上,我們需要將卷積核中心設置為1,以使其能夠讀取先前層的特征,這稱為B型掩膜
。
實現自定義層
現在,我們將為掩膜卷積創建一個自定義層。我們可以使用從基類tf.keras.layers.Layer繼承的子類在TensorFlow2.x
中創建自定義層,以便將能夠像使用其他Keras層一樣使用它。以下是自定義層類的基本結構:
class MaskedConv2D(tf.keras.layers.Layer):
def __init__(self):
...
def build(self, input_shape):
...
def call(self, inputs):
...
return output
build()將輸入張量的形狀作為參數,我們將使用此信息來確保創建正確形狀的變量。構建圖層時,此函數僅運行一次。我們可以通過聲明不可訓練的變量或常量來創建掩碼,以使TensorFlow知道它不需要梯度來反向傳播:
def build(self, input_shape):
self.w = self.add_weight(shape=[self.kernel,
self.kernel,
input_shape[-1],
self.filters],
initializer='glorot_normal',
trainable=True)
self.b = self.add_weight(shape=(self.filters,),
initializer='zeros',
trainable=True)
mask = np.ones(self.kernel**2, dtype=np.float32)
center = len(mask)//2
mask[center+1:] = 0
if self.mask_type == 'A':
mask[center] = 0
mask = mask.reshape((self.kernel, self.kernel, 1, 1))
self.mask = tf.constant(mask, dtype='float32')
call()用來執行前向傳遞。在掩膜卷積層中,在使用低級tf.nn API
執行卷積之前,我們將權重乘以掩碼后將下半部分的值設為零:
def call(self, inputs):
masked_w = tf.math.multiply(self.w, self.mask)
output=tf.nn.conv2d(inputs, masked_w, 1, "SAME") + self.b
return output
網絡架構
PixelCNN架構非常簡單。在使用A型掩膜
的第一個7 x 7 conv2d圖層之后,有幾層帶有B型掩膜
的殘差塊:
Model: "PixelCnn"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
masked_conv2d_22 (MaskedConv (None, 28, 28, 128) 6400
_________________________________________________________________
residual_block_7 (ResidualBl (None, 28, 28, 128) 53504
_________________________________________________________________
residual_block_8 (ResidualBl (None, 28, 28, 128) 53504
_________________________________________________________________
residual_block_9 (ResidualBl (None, 28, 28, 128) 53504
_________________________________________________________________
residual_block_10 (ResidualB (None, 28, 28, 128) 53504
_________________________________________________________________
residual_block_11 (ResidualB (None, 28, 28, 128) 53504
_________________________________________________________________
residual_block_12 (ResidualB (None, 28, 28, 128) 53504
_________________________________________________________________
residual_block_13 (ResidualB (None, 28, 28, 128) 53504
_________________________________________________________________
conv2d_2 (Conv2D) (None, 28, 28, 64) 8256
_________________________________________________________________
conv2d_3 (Conv2D) (None, 28, 28, 1) 65
=================================================================
Total params: 389,249
Trainable params: 389,249
Non-trainable params: 0
下圖說明了PixelCNN中使用的殘差塊架構:
交叉熵損失
交叉熵損失也稱為對數損失,它衡量模型的性能,其中輸出的概率在0到1之間。以下是二進制交叉熵損失的方程,其中只有兩個類,標簽y可以是0或1, p ( x ) p(x) p(x)是模型的預測:
B C E = − 1 N ∑ i = 1 N ( y i l o g p ( x ) + ( 1 − y i ) l o g ( 1 − p ( x ) ) ) BCE = -\frac1N\sum_{i=1}^N(y_ilogp(x)+(1-y_i)log(1-p(x))) BCE=−N1i=1∑N(yilogp(x)+(1−yi)log(1−p(x)))
在PixelCNN中,單個圖像像素用作標簽。在二值化MNIST中,我們要預測輸出像素是0還是1,這使其成為使用交叉熵作為損失函數的分類問題。
最后,編譯和訓練神經網絡,我們對損失和度量均使用二進制交叉熵,並使用RMSprop作為優化器。有許多不同的優化器可供使用,它們的主要區別在於它們根據過去的統計信息調整學習率的方式。沒有一種最佳的優化器可以在所有情況下使用,因此建議嘗試使用不同的優化器。
編譯和訓練pixelcnn模型:
pixelcnn = SimplePixelCnn()
pixelcnn.compile(
loss = tf.keras.losses.BinaryCrossentropy(),
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
metrics=[ tf.keras.metrics.BinaryCrossentropy()])
pixelcnn.fit(ds_train, epochs = 10, validation_data=ds_test)
接下來,我們將根據先前的模型生成一個新圖像。
采樣生成圖片
訓練后,我們可以通過以下步驟使用該模型生成新圖像:
- 創建一個具有與輸入圖像相同形狀的空張量,並用
1
填充。將其饋入網絡並獲得 p ( x 1 ) p(x1) p(x1),即第一個像素的概率。 - 從 p ( x 1 ) p(x_1) p(x1)進行采樣,並將采樣值分配給輸入張量中的像素 x 1 x_1 x1。
- 再次將輸入提供給網絡,並對下一個像素執行步驟2。
- 重復步驟2和3,直到生成 x N x_N xN。
自回歸模型的一個主要缺點是它生成速度慢,因為需要逐像素生成,而無法並行化。以下圖像是我們的PixelCNN模型經過100個訓練周期后生成的。它們看起來還不太像正確的數字,但我們現在可以憑空生成新圖像。可以通過訓練更長的模型並進行一些超參數調整來生成更好的數字。
完整代碼
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
print(tf.__version__)
(ds_train, ds_test), ds_info = tfds.load(
'mnist',
split=['test', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=True)
fig = tfds.show_examples(ds_train, ds_info)
def binarize(image, label):
image = tf.cast(image, tf.float32)
image = tf.math.round(image/255.)
return image, tf.cast(image, tf.int32)
ds_train = ds_train.map(binarize)
ds_train = ds_train.cache() # put dataset into memory
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_test = ds_test.map(binarize).batch(128).prefetch(64)
class MaskedConv2D(layers.Layer):
def __init__(self, mask_type, kernel=5, filters=1):
super(MaskedConv2D, self).__init__()
self.kernel = kernel
self.filters = filters
self.mask_type = mask_type
def build(self, input_shape):
self.w = self.add_weight(shape=[
self.kernel,
self.kernel,
input_shape[-1],
self.filters],
initializer='glorot_normal',
trainable=True)
self.b = self.add_weight(shape=(self.filters,),
initializer='zeros',
trainable=True)
# Create Mask
mask = np.ones(self.kernel ** 2, dtype=np.float32)
center = len(mask) // 2
mask[center+1:] = 0
if self.mask_type == 'A':
mask[center] = 0
mask = mask.reshape((self.kernel, self.kernel, 1, 1))
self.mask = tf.constant(mask, dtype='float32')
def call(self, inputs):
# mask the convolution
masked_w = tf.math.multiply(self.w, self.mask)
# preform conv2d using low level API
output = tf.nn.conv2d(inputs, masked_w, 1, 'SAME') + self.b
return tf.nn.relu(output)
class ResidualBlock(layers.Layer):
def __init__(self, h=32):
super(ResidualBlock, self).__init__()
self.forward = Sequential([
MaskedConv2D('B', kernel=1, filters=h),
MaskedConv2D('B', kernel=3, filters=h),
MaskedConv2D('B', kernel=1, filters=2*h),
])
def call(self, inputs):
x = self.forward(inputs)
return x + inputs
def SimplePixelCnn(
hidden_features=64,
output_features=64,
resblocks_num=7):
inputs = layers.Input(shape=[28, 28, 1])
x = inputs
x = MaskedConv2D('A', kernel=7, filters=2*hidden_features)(x)
for _ in range(resblocks_num):
x = ResidualBlock(hidden_features)(x)
x = layers.Conv2D(output_features, (1,1), padding='same', activation='relu')(x)
x = layers.Conv2D(1, (1,1), padding='same', activation='sigmoid')(x)
return tf.keras.Model(inputs=inputs, outputs=x, name='PixelCnn')
pixel_cnn = SimplePixelCnn()
pixel_cnn.summary()
pixel_cnn.compile(
loss=tf.keras.losses.BinaryCrossentropy(),
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
metrics=[tf.keras.losses.BinaryCrossentropy()]
)
grid_row = 5
grid_col = 5
batch = grid_row * grid_col
h = w = 28
images = np.ones((batch,h,w,1), dtype=np.float32)
for row in range(h):
for col in range(w):
prob = pixel_cnn.predict(images)[:, row,col,0]
pixel_samples = tf.random.categorical(
tf.math.log(np.stack([1-prob,prob],1)),1
)
#print(pixel_samples.shape)
images[:,row,col,0] = tf.reshape(pixel_samples, [batch])
# Display
f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*1.1,grid_row))
i = 0
for row in range(grid_row):
for col in range(grid_col):
axarr[row,col].imshow(images[i,:,:,0], cmap='gray')
axarr[row,col].axis('off')
i += 1
f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)
plt.show()