TF Boys (TensorFlow Boys ) 養成記(三): TensorFlow 變量共享


上次說到了 TensorFlow 從文件讀取數據,這次我們來談一談變量共享的問題。

為什么要共享變量?我舉個簡單的例子:例如,當我們研究生成對抗網絡GAN的時候,判別器的任務是,如果接收到的是生成器生成的圖像,判別器就嘗試優化自己的網絡結構來使自己輸出0,如果接收到的是來自真實數據的圖像,那么就嘗試優化自己的網絡結構來使自己輸出1。也就是說,生成圖像和真實圖像經過判別器的時候,要共享同一套變量,所以TensorFlow引入了變量共享機制。

變量共享主要涉及到兩個函數: tf.get_variable(<name>, <shape>, <initializer>) 和 tf.variable_scope(<scope_name>)。

先來看第一個函數: tf.get_variable。

tf.get_variable 和tf.Variable不同的一點是,前者擁有一個變量檢查機制,會檢測已經存在的變量是否設置為共享變量,如果已經存在的變量沒有設置為共享變量,TensorFlow 運行到第二個擁有相同名字的變量的時候,就會報錯。

例如如下代碼:

def my_image_filter(input_images):
    conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]),
        name="conv1_weights")
    conv1_biases = tf.Variable(tf.zeros([32]), name="conv1_biases")
    conv1 = tf.nn.conv2d(input_images, conv1_weights,
        strides=[1, 1, 1, 1], padding='SAME')
    return  tf.nn.relu(conv1 + conv1_biases)

有兩個變量(Variables)conv1_weighs, conv1_biases和一個操作(Op)conv1,如果你直接調用兩次,不會出什么問題,但是會生成兩套變量;

# First call creates one set of 2 variables.
result1 = my_image_filter(image1)
# Another set of 2 variables is created in the second call.
result2 = my_image_filter(image2)

如果把 tf.Variable 改成 tf.get_variable,直接調用兩次,就會出問題了:

result1 = my_image_filter(image1)
result2 = my_image_filter(image2)
# Raises ValueError(... conv1/weights already exists ...)

為了解決這個問題,TensorFlow 又提出了 tf.variable_scope 函數:它的主要作用是,在一個作用域 scope 內共享一些變量,可以有如下幾種用法:

1)

with tf.variable_scope("image_filters") as scope:
    result1 = my_image_filter(image1)
    scope.reuse_variables() # or 
    #tf.get_variable_scope().reuse_variables()
    result2 = my_image_filter(image2)

需要注意的是:最好不要設置 reuse 標識為 False,只在需要的時候設置 reuse 標識為 True。

2)

with tf.variable_scope("image_filters1") as scope1:
    result1 = my_image_filter(image1)
with tf.variable_scope(scope1, reuse = True)
    result2 = my_image_filter(image2)

 

 

通常情況下,tf.variable_scope 和 tf.name_scope 配合,能畫出非常漂亮的流程圖,但是他們兩個之間又有着細微的差別,那就是 name_scope 只能管住操作 Ops 的名字,而管不住變量 Variables 的名字,看下例:

with tf.variable_scope("foo"):
    with tf.name_scope("bar"):
        v = tf.get_variable("v", [1])
        x = 1.0 + v
assert v.name == "foo/v:0"
assert x.op.name == "foo/bar/add"

 

 

參考資料:

1. https://www.tensorflow.org/how_tos/variable_scope/

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM