tfgan折騰筆記(一):核心功能簡要概述


tfgan是什么?

tfgan是tensorflow團隊開發出的一個專門用於訓練各種GAN的輕量級庫,它是基於tensorflow開發的,所以兼容於tensorflow。在tensorflow1.x版本中,tfgan存在於tensorflow.contrib中,作為一個小模塊供使用者調用。在更新到tensorflow2.0版本后,tfgan成為一個獨立的庫。可使用:

pip install tensorflow-gan

 進行下載安裝,並在python中使用以下語句導入這個包:

import tensorflow_gan as tfgan

 可以使用tfgan對目前流行的GAN模型進行訓練。並且,tfgan維護團隊也會不斷更新tfgan,使得其可以對論文中最新提出的GAN模型進行訓練。

tfgan項目托管在github中,點擊這里可以查看tfgan在github中托管的源代碼及其官方教程與示例。

tfgan核心功能

tfgan的中函數的功能主要集中在基於tensorflow的LOSS函數、優化器、訓練迭代的封裝,以及對GAN模型的評估。其它的如數據集的輸入、生成器和判別器模型的結構以及推斷過程則需要通過調用tensorflow函數自己編寫。即使這樣,tfgan也極大的簡化了GAN的訓練與實現。接下來就針對tfgan中的幾個核心功能對應的函數進行一個預覽,以便對tfgan有一個初步印象。具體的用法將在后續文章中詳細說明注意:以下的代碼中的函數均為調用,而不是函數原型

tfgan核心函數示例

·初始化模型

以Original-GAN為例進行說明,其它的例如C-GAN, info-GAN, Cycle-GAN等的情況與此處略有不同,在后續文章中會有具體說明

gan_model = tfgan.gan_model(
    generator_fn=generator,
    discriminator_fn=discriminator,
    real_data=images,
    generator_inputs=tf.random.normal(
        [batch_size, noise_dims]
    )
)

 在tfgan中,調用gan_model函數以創建Original-GAN網絡模型,其主要參數包含4個,以下詳細說明:

generator_fn:需要先自定義一個生成器函數,函數中定義判別器網絡模型,並將函數名稱作為參數傳入。定義的生成器函數的接口應當符合如下格式:

def generator(noise, weight_decay=2.5e-5, is_training=True):
    '''GAN Generator.

    Args:
        noise: A 2D Tensor of shape [batch size, noise dim].
        weight_decay: The value of the l2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`, batch
            norm uses the exponential moving average collected from population
            statistics.

    Returns:
        A generated image.
    '''

 discriminator_fn:同樣,需要首先自定義一個判別器函數,函數中定義判別器網絡模型。並將函數名稱作為參數傳入。定義的判別器函數的接口應當符合如下格式:

def discriminator(img, unused_conditioning, weight_decay=2.5e-5):
    '''GAN discriminator.

    Args:
        img: Real or generated MNIST digits. Should be in the range [-1, 1].
        unuseed_conditioning: The TFGAN API can help with conditional GANs, which
            would require extra `condition` information to both the generator and the
            discriminator. Since this example is not conditional, we do not use this
            argument.
        weight_decay: The L2 weight decay.

    Returns:
        Logits for the probability that the image is real.
    '''

 real_data:真實圖像。一個batch的Tensor格式。

generator_inputs:輸入GAN的隨機噪聲,一般通過tf.random.normal()函數獲得。

·指定損失函數

使用gan_loss函數指定訓練GAN時所需要的損失函數,若調用形式如下所示,使用默認的損失函數:

gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)

 gan_model:上一步初始化模型時的返回值。

add_summaries:是否添加損失的總結。tfgan在訓練時,會自動生成tensorboard的日志信息(日志的位置將在最后一步“gan_train”函數中指定,tensorboard是一個適配於tensorflow的訓練過程可視化工具),若為True,將添加loss的信息到日志中。

或者使用tfgan中內置的其它loss函數,下面的函數調用時就使用了帶權重懲罰的W距離。或者可以自己自定義loss函數,此處不再詳述。

gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.modified_generator_loss,
    discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
    mutual_information_penalty_weight=1.0,
    add_summaries=True
)

·指定優化器

train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(3e-3, 0.5),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(3e-4, 0.5),
    summarize_gradients=True
)

 優化器一般需要傳遞4個參數:

gan_model:第一步調用tfgan.gan_model的返回值

gan_loss:第二步調用tfgan.gan_loss的返回值

generator_optimizer:指定生成器的優化器

discriminator_optimizer:指定判別器的優化器

summarize_gradients:添加梯度的總結

·開始訓練

tfgan.gan_train(
    train_ops,
    hooks=[
        tf.estimator.StopAtStepHook(num_steps=max_number_of_steps),
        tf.estimator.LoggingTensorHook([status_message], every_n_iter=20)
    ],
    logdir=train_log_dir,
    get_hooks_fn=tfgan.get_joint_train_hooks(),
    save_checkpoint_secs=60
)

 參數解釋:

train_ops:上一步函數的返回值

hooks:tf.train.SessionRunHook類型的回調函數,用列表形式封裝。此處的函數將在每次訓練迭代時調用

logdir:tfgan自動將建立好的網絡模型以及訓練過程的參數變化存儲下來,此參數即為存儲的位置

get_hooks_fn:G和D的訓練方式,get_joint_train_hooks()意為進行一次G+D的參數更新,然后再單獨進行一次D的參數更新。以此為一個迭代周期。

save_checkpoint_secs:訓練過程中參數存儲周期,此處設置為60s存儲一次網絡參數。

調用gan_train函數后,訓練開始進行。

使用tfgan進行GAN網絡訓練步驟:

1.定義Generator與Discriminator網絡模型;

2.加載訓練集數據為batch形式;

3.調用gan_model函數以初始化網絡模型;

4.調用gan_loss函數以指定損失函數;

5.調用gan_train_ops函數以指定優化器;

6.調用gan_train函數開始訓練;

7.訓練完畢后,tfgan自動將網絡模型及參數以及訓練過程的總結(summarise)存儲在硬盤中。

使用tfgan進行推斷的步驟:

1.從tfgan保存的日志中加載網絡模型及參數;

2.加載測試數據;

3.將數據傳入(feed)網絡,得到結果。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM