使用Keras編寫GAN的入門


使用Keras編寫GAN的入門

Time: 2017-5-31


前言

本文主要參考了網頁 [1] 的教程，核心算法則來自 Ian J. Goodfellow 等人的論文《Generative Adversarial Nets》(2014)，算法如下：

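（原文此處為兩張 GAN 算法示意圖，圖片已缺失。）論文中判別器 $D$ 與生成器 $G$ 進行如下極小極大博弈：

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$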

代碼

%matplotlib inline
import numpy as np
import pandas as pd

from keras.models import Model
from keras.layers import Dense, Activation, Input, Reshape
from keras.layers import Conv1D, Flatten, Dropout
from keras.optimizers import SGD, Adam


from tqdm import tqdm_notebook as tqdm  # 進度條


# Generate random sine-curve data (the "real" samples for the GAN).
def sample_data(n_samples=10000, x_vals=np.arange(0, 5, .1), max_offset=1000, mul_range=(1, 2)):
    """Generate ``n_samples`` random sine curves scaled into ``[0, 1]``.

    Each row is ``sin(offset + x_vals * mul) / 2 + 0.5``, where ``offset`` is
    drawn uniformly from ``[0, max_offset)`` and ``mul`` uniformly from
    ``[mul_range[0], mul_range[1])``.

    Args:
        n_samples: number of curves (rows) to generate.
        x_vals: sample positions shared by every curve (any array-like). With
            the default ``np.arange(0, 5, .1)`` each curve has 50 points.
        max_offset: upper bound of the random phase offset.
        mul_range: ``(low, high)`` bounds of the random frequency multiplier.
            The default is a tuple to avoid a shared mutable default; any
            indexable pair (e.g. a list) is still accepted.

    Returns:
        ``np.ndarray`` of shape ``(n_samples,) + np.shape(x_vals)``, even when
        ``n_samples == 0``, so it can always be concatenated with generator
        output.
    """
    x_vals = np.asarray(x_vals)
    low, high = mul_range[0], mul_range[1]
    curves = []
    for _ in range(n_samples):
        # Draw the offset before the multiplier, so a seeded RNG still yields
        # the same curves as earlier versions of this function.
        offset = np.random.random() * max_offset
        mul = low + np.random.random() * (high - low)
        curves.append(np.sin(offset + x_vals * mul) / 2 + .5)
    # np.array([]) has shape (0,); the reshape restores the per-curve axis so an
    # empty request still returns shape (0, len(x_vals)).
    return np.array(curves).reshape((n_samples,) + x_vals.shape)
    
# Build the generative model (generator).
def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3):
    """Build and compile the generator network.

    Architecture: ``Dense(dense_dim) -> tanh -> Dense(out_dim, activation='tanh')``.

    Args:
        G_in: Keras ``Input`` tensor carrying the noise vector.
        dense_dim: width of the hidden layer.
        out_dim: length of each generated curve.
        lr: learning rate of the SGD optimizer.

    Returns:
        ``(G, G_out)``: the compiled generator model and its output tensor.
    """
    hidden = Dense(dense_dim)(G_in)
    hidden = Activation('tanh')(hidden)
    # NOTE(review): tanh bounds the output to [-1, 1], whereas sample_data()
    # produces values in [0, 1]; the generator must learn to stay in that range.
    G_out = Dense(out_dim, activation='tanh')(hidden)

    G = Model(G_in, G_out)
    # In this file G is never fit directly; make_gan() reuses G.optimizer when
    # compiling the combined model.
    G.compile(optimizer=SGD(lr=lr), loss='binary_crossentropy')
    return G, G_out
    
# Build the discriminative model (discriminator).
def get_discriminative(D_in, lr=1e-3, drate=.25, n_channels=50, conv_sz=5, leak=.2):
    """Build and compile the discriminator network.

    Architecture: reshape to ``(length, 1)`` -> ``Conv1D(n_channels, conv_sz, relu)``
    -> ``Dropout(drate)`` -> ``Flatten`` -> ``Dense(n_channels)`` (linear)
    -> ``Dense(2, activation='sigmoid')``.

    The two outputs pair with the one-hot labels from sample_data_and_gen():
    column 1 is set for real curves and column 0 for generated ones.

    Args:
        D_in: Keras ``Input`` tensor carrying a curve.
        lr: learning rate of the Adam optimizer.
        drate: dropout rate applied after the convolution.
        n_channels: number of conv filters and width of the hidden dense layer.
        conv_sz: convolution kernel size.
        leak: currently unused by the body; kept for interface compatibility.

    Returns:
        ``(D, D_out)``: the compiled discriminator model and its output tensor.
    """
    feature_stack = (
        Reshape((-1, 1)),  # Conv1D expects a trailing channel axis
        Conv1D(n_channels, conv_sz, activation='relu'),
        Dropout(drate),
        Flatten(),
        Dense(n_channels),
    )
    features = D_in
    for layer in feature_stack:
        features = layer(features)
    D_out = Dense(2, activation='sigmoid')(features)

    D = Model(D_in, D_out)
    D.compile(optimizer=Adam(lr=lr), loss='binary_crossentropy')
    return D, D_out

    
    
def set_trainability(model, trainable=False):
    """Set the ``trainable`` flag on *model* itself and on each of its layers.

    The model is updated first, then every entry of ``model.layers`` in order.

    NOTE(review): in Keras the flag generally affects only models compiled
    after it is set; confirm this for the Keras version in use.
    """
    for target in [model] + list(model.layers):
        target.trainable = trainable
        
def make_gan(GAN_in, G, D):
    """Stack generator and discriminator into the combined GAN model.

    D is marked non-trainable before the combined model is compiled, so
    training the returned model is meant to update only G's weights.

    Args:
        GAN_in: Keras ``Input`` tensor for the noise fed to G.
        G: generator model (its optimizer is reused for the combined model).
        D: discriminator model.

    Returns:
        ``(GAN, GAN_out)``: the compiled combined model and its output tensor.
    """
    set_trainability(D, False)
    GAN_out = D(G(GAN_in))
    GAN = Model(GAN_in, GAN_out)
    GAN.compile(optimizer=G.optimizer, loss='binary_crossentropy')
    return GAN, GAN_out

# Pre-train the discriminator using real and generated data.
def sample_data_and_gen(G, noise_dim=10, n_samples=10000):
    """Build a labelled discriminator batch of real and generated curves.

    Args:
        G: generator model; ``G.predict`` is called on uniform [0, 1) noise.
        noise_dim: length of each noise vector fed to G.
        n_samples: number of real curves and, separately, of generated curves.

    Returns:
        ``(X, y)`` where ``X`` stacks the ``n_samples`` real curves first and
        the ``n_samples`` generated curves after them. ``y`` has shape
        ``(2 * n_samples, 2)`` with one-hot rows: ``[0, 1]`` for real and
        ``[1, 0]`` for generated.
    """
    real_curves = sample_data(n_samples=n_samples)
    noise = np.random.uniform(0, 1, size=[n_samples, noise_dim])
    fake_curves = G.predict(noise)
    X = np.concatenate((real_curves, fake_curves))
    # Repeat each label row n_samples times: the real block first, then the fake block.
    y = np.repeat([[0., 1.], [1., 0.]], n_samples, axis=0)
    return X, y
     
def pretrain(G, D, noise_dim=10, n_samples=10000, batch_size=32):
    """Fit the discriminator for one epoch on real + generated curves.

    Args:
        G: generator used to produce the fake half of the batch.
        D: discriminator to train.
        noise_dim: length of each noise vector fed to G.
        n_samples: number of real and of generated curves (2 * n_samples total).
        batch_size: minibatch size passed to ``D.fit``.
    """
    features, labels = sample_data_and_gen(G, noise_dim=noise_dim, n_samples=n_samples)
    set_trainability(D, True)  # make sure D's weights are updated by fit()
    D.fit(features, labels, batch_size=batch_size, epochs=1)
    
    
# Alternating training steps (discriminator step, then generator step).
def sample_noise(G, noise_dim=10, n_samples=10000):
    """Draw generator inputs that are all labelled as "real".

    Labelling the noise as real (``[0, 1]``) means the GAN loss rewards the
    generator for making D classify its output as real.

    Args:
        G: unused by the body; kept for interface compatibility.
        noise_dim: length of each noise vector.
        n_samples: number of noise vectors.

    Returns:
        ``(X, y)``: ``X`` is uniform [0, 1) noise of shape
        ``(n_samples, noise_dim)``, and ``y`` has shape ``(n_samples, 2)`` with
        every row equal to ``[0, 1]``.
    """
    noise = np.random.uniform(0, 1, size=[n_samples, noise_dim])
    targets = np.tile([0., 1.], (n_samples, 1))
    return noise, targets
    
def train(GAN, G, D, epochs=500, n_samples=10000, noise_dim=10, batch_size=32, verbose=False, v_freq=50):
    """Alternate discriminator and generator updates for ``epochs`` steps.

    Each "epoch" performs one ``train_on_batch`` call on D and then one on the
    combined GAN; each call uses a freshly sampled batch.

    Args:
        GAN: combined model built by make_gan().
        G: generator, used to sample fake curves for D's step.
        D: discriminator.
        epochs: number of alternating steps.
        n_samples: batch size per step (D sees 2 * n_samples rows).
        noise_dim: length of each noise vector.
        batch_size: currently unused; every step trains on the full batch.
        verbose: wrap the loop in a tqdm progress bar and print losses.
        v_freq: print the losses every ``v_freq`` steps when verbose.

    Returns:
        ``(d_loss, g_loss)``: lists of the per-step values returned by
        ``D.train_on_batch`` and ``GAN.train_on_batch``.
    """
    d_loss, g_loss = [], []
    steps = tqdm(range(epochs)) if verbose else range(epochs)

    for epoch in steps:
        # Discriminator step on a fresh mix of real and generated curves.
        X_d, y_d = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim)
        set_trainability(D, True)
        d_loss.append(D.train_on_batch(X_d, y_d))

        # Generator step through the combined model, with noise labelled "real".
        X_g, y_g = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim)
        set_trainability(D, False)
        g_loss.append(GAN.train_on_batch(X_g, y_g))

        step = epoch + 1
        if verbose and step % v_freq == 0:
            print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(step, g_loss[-1], d_loss[-1]))

    return d_loss, g_loss

# Plot five example "real" curves; after the transpose each column is one curve.
ax = pd.DataFrame(np.transpose(sample_data(5))).plot()
# Generator takes a 10-dimensional noise vector (matches noise_dim=10 used when sampling).
G_in = Input(shape=[10])
G, G_out = get_generative(G_in)
G.summary()

# Discriminator takes a 50-point curve (len(np.arange(0, 5, .1)) == 50).
D_in = Input(shape=[50])
D, D_out = get_discriminative(D_in)
D.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_9 (InputLayer)         (None, 10)                0         
_________________________________________________________________
dense_13 (Dense)             (None, 200)               2200      
_________________________________________________________________
activation_4 (Activation)    (None, 200)               0         
_________________________________________________________________
dense_14 (Dense)             (None, 50)                10050     
=================================================================
Total params: 12,250
Trainable params: 12,250
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_10 (InputLayer)        (None, 50)                0         
_________________________________________________________________
reshape_4 (Reshape)          (None, 50, 1)             0         
_________________________________________________________________
conv1d_4 (Conv1D)            (None, 46, 50)            300       
_________________________________________________________________
dropout_4 (Dropout)          (None, 46, 50)            0         
_________________________________________________________________
flatten_4 (Flatten)          (None, 2300)              0         
_________________________________________________________________
dense_15 (Dense)             (None, 50)                115050    
_________________________________________________________________
dense_16 (Dense)             (None, 2)                 102       
=================================================================
Total params: 115,452
Trainable params: 115,452
Non-trainable params: 0
_________________________________________________________________

（圖：`sample_data(5)` 生成的 5 條隨機正弦曲線樣本；原圖缺失）

# Stack G and D into the combined model; make_gan marks D non-trainable before compiling.
GAN_in = Input([10])
GAN, GAN_out = make_gan(GAN_in, G, D)
GAN.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        (None, 10)                0         
_________________________________________________________________
model_9 (Model)              (None, 50)                12250     
_________________________________________________________________
model_10 (Model)             (None, 2)                 115452    
=================================================================
Total params: 127,702
Trainable params: 12,250
Non-trainable params: 115,452
_________________________________________________________________
# Pre-train D for one epoch on real + generated curves (default 10000 of each).
pretrain(G, D)
Epoch 1/1
20000/20000 [==============================] - 3s - loss: 0.0072     
# Run 500 alternating D/G steps (defaults), printing the losses every 50 steps.
d_loss, g_loss = train(GAN, G, D, verbose=True)
Epoch #50: Generative Loss: 4.41527795791626, Discriminative Loss: 0.6733301877975464
Epoch #100: Generative Loss: 3.8898046016693115, Discriminative Loss: 0.09901376813650131
Epoch #150: Generative Loss: 6.2410054206848145, Discriminative Loss: 0.034074194729328156
Epoch #200: Generative Loss: 5.206066608428955, Discriminative Loss: 0.13078376650810242
Epoch #250: Generative Loss: 3.5144925117492676, Discriminative Loss: 0.07160962373018265
Epoch #300: Generative Loss: 3.705162525177002, Discriminative Loss: 0.05893774330615997
Epoch #350: Generative Loss: 3.511479616165161, Discriminative Loss: 0.09775738418102264
Epoch #400: Generative Loss: 4.141300678253174, Discriminative Loss: 0.03169865906238556
Epoch #450: Generative Loss: 3.500260829925537, Discriminative Loss: 0.05957922339439392
Epoch #500: Generative Loss: 2.9797921180725098, Discriminative Loss: 0.10566817969083786
# Plot both loss histories on a log-scale y axis; each x position is one
# alternating training step from train().
ax = pd.DataFrame(
    {
        'Generative Loss': g_loss,
        'Discriminative Loss': d_loss,
    }
).plot(title='Training loss', logy=True)
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")

（圖：生成模型與判別模型的訓練損失曲線，縱軸為對數刻度；原圖缺失）

# Inspect a few curves produced by the trained generator.
N_VIEWED_SAMPLES = 2
data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
# Real rows come first and generated rows follow, so this slice keeps only the generated curves.
pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).plot()

（圖：訓練後生成模型輸出的 2 條曲線；原圖缺失）

# Same view as above, smoothed with a 5-point moving average.
N_VIEWED_SAMPLES = 2
data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
# Real rows come first and generated rows follow, so slice off the real ones.
# rolling(5).mean() leaves NaN only in the first 4 rows. dropna() removes exactly
# those rows; the previous [5:] slice also discarded the first valid smoothed point.
pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).rolling(5).mean().dropna().plot()

（圖：生成曲線經 5 點移動平均平滑後的結果；原圖缺失）

參考資料

[1] http://www.rricard.me/machine/learning/generative/adversarial/networks/keras/tensorflow/2017/04/05/gans-part2.html#Imports


免責聲明!

本站轉載的文章為個人學習借鑒使用，本站對版權不負任何法律責任。如果侵犯了您的隱私權益，請聯繫本站郵箱 yoyou2525@163.com 刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM