生成模型——自回歸模型詳解與PixelCNN構建


自回歸模型(Autoregressive models)

深度神經網絡生成算法主要分為三類:

  1. 生成對抗網絡(Generative Adversarial Network, GAN)
  2. 可變自動編碼器(Variational Autoencoder, VAE)
  3. 自回歸模型(Autoregressive models)

VAE已經在《變分自編碼器(VAE)原理與實現(tensorflow2.x)》中進行了介紹。GAN的詳細信息參考《深度卷積生成對抗網絡(DCGAN)原理與實現(采用Tensorflow2.x)》。將在這里介紹鮮為人知的自回歸模型,盡管自回歸在圖像生成中並不常見,但自回歸仍然是研究的活躍領域,DeepMind的WaveNet使用自回歸來生成逼真的音頻。在本文中,將介紹自回歸模型並構建PixelCNN模型。

簡介

Autoregressive中的“Auto”意味着自我(self),而機器學習術語的回歸(regress)意味着預測新的值。將它們放在一起,自回歸意味着我們使用模型基於模型的過去數據點來預測新數據點。
設圖像的概率分布是 p ( x ) p(x) p(x)是像素的聯合概率分布 p ( x 1 , x 2 , … x n ) p(x_1, x_2, …x_n) p(x1,x2,xn),由於高維數很難建模。在這里,我們假設一個像素的值僅取決於它之前的像素的值。換句話說,當前像素僅以其前一像素為條件,即 p ( x i ) = p ( x i ∣ x i − 1 ) p ( x i − 1 ) p(x_i) = p(x_i | x_{i-1}) p(x_{i-1}) p(xi)=p(xixi1)p(xi1),我們就可以將聯合概率近似為條件概率的乘積:
p ( x ) = p ( x n , x n − 1 , … , x 2 , x 1 ) p(x) = p(x_n, x_{n-1}, …, x_2, x_1) p(x)=p(xn,xn1,,x2,x1)
p ( x ) = p ( x n ∣ x n − 1 ) . . . p ( x 3 ∣ x 2 ) p ( x 2 ∣ x 1 ) p ( x 1 ) p(x) = p(x_n | x_{n-1})... p(x_3 | x_2) p(x_2 | x_1) p(x_1) p(x)=p(xnxn1)...p(x3x2)p(x2x1)p(x1)
舉一個具體的例子,假設在圖像的中心附近包含一個紅色的蘋果,並且該蘋果被綠葉包圍,在此情況下,假設僅存在兩種可能的顏色:紅色和綠色。 x 1 x_1 x1是左上像素,所以 p ( x 1 ) p(x_1) p(x1)表示左上像素是綠色還是紅色的概率。如果 x 1 x_1 x1為綠色,則其右邊 p ( x 2 ) p(x_2) p(x2)的像素也可能也為綠色,因為它可能會有更多的葉子。但是,盡管可能性較小,但它也可能是紅色的。
繼續進行計算,我們最終將得到紅色像素。從那個像素開始,接下來的幾個像素也很可能也是紅色的,這比必須同時考慮所有像素要簡單得多。

PixelRNN

PixelRNN由DeepMind於2016年提出。正如名稱RNN(Recurrent Neural Network, 遞歸神經網絡)所暗示的那樣,該模型使用一種稱為長短期記憶(LSTM)的RNN來學習圖像的分布。它在LSTM中的一個步驟中一次讀取圖像的一行,並使用一維卷積層對其進行處理,然后將激活信息饋送到后續層中以預測該行的像素。
由於LSTM運行緩慢,因此需要花費很長時間來訓練和生成樣本。因此,我們不會對其進行過多的研究,而將注意力轉移到同一論文中提出的一種變體——PixelCNN。

使用TensorFlow 2構建PixelCNN模型

PixelCNN僅由卷積層組成,使其比PixelRNN快得多。在這里,我們將為使用MNIST數據集訓練一個簡單的PixelCNN模型。

輸入和標簽

MNIST由28 x 28 x 1灰度數字手寫數字組成。它只有一個通道:
數據集查看
在本實驗中,通過將圖像轉換為二進制數據來簡化問題:0代表黑色,1代表白色:

def binarize(image, label):
    image = tf.cast(image, tf.float32)
    image = tf.math.round(image/255.)
    return image, tf.cast(image, tf.int32)

該函數需要兩個輸入——圖像和標簽。該函數的前兩行將圖像轉換為二進制float32格式,即0.0或1.0。並且,我們將二進制圖像轉換為整數並返回,以遵循使用整數作為標簽的慣例而已。返回的數據,將作為網絡訓練的輸入和標簽,都是28 x 28 x 1的二進制MNIST圖像,它們僅在數據類型上有所不同。

掩膜

與PixelRNN逐行讀取不同,PixelCNN在圖像中從左到右,從上到下滑動卷積核。當執行卷積以預測當前像素時,傳統的卷積核能夠看到當前輸入像素及其周圍的像素,其中包括當前像素之后的像素信息,這與在簡介部分的條件概率假設相悖。
為了避免這種情況,我們需要確保CNN在預測輸出像素 x i x_i xi時不會看到輸入像素 x i x_i xi
這是通過使用掩膜卷積來實現的,其中在執行卷積之前將掩膜應用於卷積核權重。下圖顯示了一個7 x 7卷積核的掩膜,其中從中心開始的權重為0。這會阻止CNN看到它正在預測的像素(卷積核的中心)以及所有之后的像素。這稱為A型掩膜,僅應用於輸入層。
掩膜卷積核

由於中心像素在第一層中被遮擋,因此我們不再需要在后面的層中隱藏中心要素。實際上,我們需要將卷積核中心設置為1,以使其能夠讀取先前層的特征,這稱為B型掩膜

實現自定義層

現在,我們將為掩膜卷積創建一個自定義層。我們可以使用從基類tf.keras.layers.Layer繼承的子類在TensorFlow2.x中創建自定義層,以便將能夠像使用其他Keras層一樣使用它。以下是自定義層類的基本結構:

class MaskedConv2D(tf.keras.layers.Layer):
    def __init__(self):
        ...       
    def build(self, input_shape):
        ...
    def call(self, inputs):
        ...
        return output

build()將輸入張量的形狀作為參數,我們將使用此信息來確保創建正確形狀的變量。構建圖層時,此函數僅運行一次。我們可以通過聲明不可訓練的變量或常量來創建掩碼,以使TensorFlow知道它不需要梯度來反向傳播:

	def build(self, input_shape):
	        self.w = self.add_weight(shape=[self.kernel,
	                                        self.kernel,
	                                        input_shape[-1],
	                                        self.filters],
	                                initializer='glorot_normal',
	                                trainable=True)
	        self.b = self.add_weight(shape=(self.filters,),
	                                initializer='zeros',
	                                trainable=True)
	        mask = np.ones(self.kernel**2, dtype=np.float32)
	        center = len(mask)//2
	        mask[center+1:] = 0
	        if self.mask_type == 'A':
	            mask[center] = 0
	        mask = mask.reshape((self.kernel, self.kernel, 1, 1))
	        self.mask = tf.constant(mask, dtype='float32')

call()用來執行前向傳遞。在掩膜卷積層中,在使用低級tf.nn API執行卷積之前,我們將權重乘以掩碼后將下半部分的值設為零:

	def call(self, inputs):
	        masked_w = tf.math.multiply(self.w, self.mask)
	        output=tf.nn.conv2d(inputs, masked_w, 1, "SAME") + self.b
	        return output

網絡架構

PixelCNN架構非常簡單。在使用A型掩膜的第一個7 x 7 conv2d圖層之后,有幾層帶有B型掩膜的殘差塊:

Model: "PixelCnn"
_________________________________________________________________
Layer (type)                 Output Shape              Param # 
=================================================================
input_2 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
masked_conv2d_22 (MaskedConv (None, 28, 28, 128)       6400      
_________________________________________________________________
residual_block_7 (ResidualBl (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_8 (ResidualBl (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_9 (ResidualBl (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_10 (ResidualB (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_11 (ResidualB (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_12 (ResidualB (None, 28, 28, 128)       53504     
_________________________________________________________________
residual_block_13 (ResidualB (None, 28, 28, 128)       53504     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 64)        8256      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 1)         65        
=================================================================
Total params: 389,249
Trainable params: 389,249
Non-trainable params: 0

下圖說明了PixelCNN中使用的殘差塊架構:
殘差塊架構

交叉熵損失

交叉熵損失也稱為對數損失,它衡量模型的性能,其中輸出的概率在0到1之間。以下是二進制交叉熵損失的方程,其中只有兩個類,標簽y可以是0或1, p ( x ) p(x) p(x)是模型的預測:
B C E = − 1 N ∑ i = 1 N ( y i l o g p ( x ) + ( 1 − y i ) l o g ( 1 − p ( x ) ) ) BCE = -\frac1N\sum_{i=1}^N(y_ilogp(x)+(1-y_i)log(1-p(x))) BCE=N1i=1N(yilogp(x)+(1yi)log(1p(x)))
在PixelCNN中,單個圖像像素用作標簽。在二值化MNIST中,我們要預測輸出像素是0還是1,這使其成為使用交叉熵作為損失函數的分類問題。
最后,編譯和訓練神經網絡,我們對損失和度量均使用二進制交叉熵,並使用RMSprop作為優化器。有許多不同的優化器可供使用,它們的主要區別在於它們根據過去的統計信息調整學習率的方式。沒有一種最佳的優化器可以在所有情況下使用,因此建議嘗試使用不同的優化器。
編譯和訓練pixelcnn模型:

pixelcnn = SimplePixelCnn()
pixelcnn.compile(
    loss = tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    metrics=[ tf.keras.metrics.BinaryCrossentropy()])
pixelcnn.fit(ds_train, epochs = 10, validation_data=ds_test)

接下來,我們將根據先前的模型生成一個新圖像。

采樣生成圖片

訓練后,我們可以通過以下步驟使用該模型生成新圖像:

  1. 創建一個具有與輸入圖像相同形狀的空張量,並用1填充。將其饋入網絡並獲得 p ( x 1 ) p(x1) p(x1),即第一個像素的概率。
  2. p ( x 1 ) p(x_1) p(x1)進行采樣,並將采樣值分配給輸入張量中的像素 x 1 x_1 x1
  3. 再次將輸入提供給網絡,並對下一個像素執行步驟2。
  4. 重復步驟2和3,直到生成 x N x_N xN

自回歸模型的一個主要缺點是它生成速度慢,因為需要逐像素生成,而無法並行化。以下圖像是我們的PixelCNN模型經過100個訓練周期后生成的。它們看起來還不太像正確的數字,但我們現在可以憑空生成新圖像。可以通過訓練更長的模型並進行一些超參數調整來生成更好的數字。
生成結果圖片

完整代碼

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential

import tensorflow_datasets as tfds

import numpy as np
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')
print(tf.__version__)

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['test', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True)
fig = tfds.show_examples(ds_train, ds_info)

def binarize(image, label):
    image = tf.cast(image, tf.float32)
    image = tf.math.round(image/255.)
    return image, tf.cast(image, tf.int32)
    
ds_train = ds_train.map(binarize)
ds_train = ds_train.cache() # put dataset into memory
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_test = ds_test.map(binarize).batch(128).prefetch(64)

class MaskedConv2D(layers.Layer):
    def __init__(self, mask_type, kernel=5, filters=1):
        super(MaskedConv2D, self).__init__()
        self.kernel = kernel
        self.filters = filters
        self.mask_type = mask_type
    
    def build(self, input_shape):
        
        self.w = self.add_weight(shape=[
            self.kernel, 
            self.kernel, 
            input_shape[-1],
            self.filters],
            initializer='glorot_normal',
            trainable=True)
        
        self.b = self.add_weight(shape=(self.filters,),
            initializer='zeros',
            trainable=True)
        
        # Create Mask
        mask = np.ones(self.kernel ** 2, dtype=np.float32)
        center = len(mask) // 2
        mask[center+1:] = 0
        if self.mask_type == 'A':
            mask[center] = 0
        
        mask = mask.reshape((self.kernel, self.kernel, 1, 1))

        self.mask = tf.constant(mask, dtype='float32')
    
    def call(self, inputs):
        # mask the convolution
        masked_w = tf.math.multiply(self.w, self.mask)

        # preform conv2d using low level API
        output = tf.nn.conv2d(inputs, masked_w, 1, 'SAME') + self.b

        return tf.nn.relu(output)

class ResidualBlock(layers.Layer):
    def __init__(self, h=32):
        super(ResidualBlock, self).__init__()

        self.forward = Sequential([
            MaskedConv2D('B', kernel=1, filters=h),
            MaskedConv2D('B', kernel=3, filters=h),
            MaskedConv2D('B', kernel=1, filters=2*h),
        ])
    
    def call(self, inputs):
        x = self.forward(inputs)
        return x + inputs

def SimplePixelCnn(
    hidden_features=64,
    output_features=64,
    resblocks_num=7):
    
    inputs = layers.Input(shape=[28, 28, 1])
    x = inputs
    
    x = MaskedConv2D('A', kernel=7, filters=2*hidden_features)(x)

    for _ in range(resblocks_num):
        x = ResidualBlock(hidden_features)(x)
    
    x = layers.Conv2D(output_features, (1,1), padding='same', activation='relu')(x)
    x = layers.Conv2D(1, (1,1), padding='same', activation='sigmoid')(x)

    return tf.keras.Model(inputs=inputs, outputs=x, name='PixelCnn')

pixel_cnn = SimplePixelCnn()
pixel_cnn.summary()

pixel_cnn.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    metrics=[tf.keras.losses.BinaryCrossentropy()]
)

grid_row = 5
grid_col = 5
batch = grid_row * grid_col
h = w = 28
images = np.ones((batch,h,w,1), dtype=np.float32)

for row in range(h):
    for col in range(w):
        prob = pixel_cnn.predict(images)[:, row,col,0]
        pixel_samples = tf.random.categorical(
            tf.math.log(np.stack([1-prob,prob],1)),1
        )
        #print(pixel_samples.shape)
        images[:,row,col,0] = tf.reshape(pixel_samples, [batch])

# Display
f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*1.1,grid_row))

i = 0
for row in range(grid_row):
    for col in range(grid_col):
        axarr[row,col].imshow(images[i,:,:,0], cmap='gray')
        axarr[row,col].axis('off')
        i += 1
f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)        
plt.show()


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM