上次說到了 TensorFlow 從文件讀取數據,這次我們來談一談變量共享的問題。
為什么要共享變量?我舉個簡單的例子:例如,當我們研究生成對抗網絡GAN的時候,判別器的任務是,如果接收到的是生成器生成的圖像,判別器就嘗試優化自己的網絡結構來使自己輸出0,如果接收到的是來自真實數據的圖像,那么就嘗試優化自己的網絡結構來使自己輸出1。也就是說,生成圖像和真實圖像經過判別器的時候,要共享同一套變量,所以TensorFlow引入了變量共享機制。
變量共享主要涉及到兩個函數: tf.get_variable(<name>, <shape>, <initializer>) 和 tf.variable_scope(<scope_name>)。
先來看第一個函數: tf.get_variable。
tf.get_variable 和tf.Variable不同的一點是,前者擁有一個變量檢查機制,會檢測已經存在的變量是否設置為共享變量,如果已經存在的變量沒有設置為共享變量,TensorFlow 運行到第二個擁有相同名字的變量的時候,就會報錯。
例如如下代碼:
def my_image_filter(input_images): conv1_weights = tf.Variable(tf.random_normal([5, 5, 32, 32]), name="conv1_weights") conv1_biases = tf.Variable(tf.zeros([32]), name="conv1_biases") conv1 = tf.nn.conv2d(input_images, conv1_weights, strides=[1, 1, 1, 1], padding='SAME') return tf.nn.relu(conv1 + conv1_biases)
有兩個變量(Variables)conv1_weighs, conv1_biases和一個操作(Op)conv1,如果你直接調用兩次,不會出什么問題,但是會生成兩套變量;
# First call creates one set of 2 variables. result1 = my_image_filter(image1) # Another set of 2 variables is created in the second call. result2 = my_image_filter(image2)
如果把 tf.Variable 改成 tf.get_variable,直接調用兩次,就會出問題了:
result1 = my_image_filter(image1) result2 = my_image_filter(image2) # Raises ValueError(... conv1/weights already exists ...)
為了解決這個問題,TensorFlow 又提出了 tf.variable_scope 函數:它的主要作用是,在一個作用域 scope 內共享一些變量,可以有如下幾種用法:
1)
with tf.variable_scope("image_filters") as scope: result1 = my_image_filter(image1) scope.reuse_variables() # or #tf.get_variable_scope().reuse_variables() result2 = my_image_filter(image2)
需要注意的是:最好不要設置 reuse 標識為 False,只在需要的時候設置 reuse 標識為 True。
2)
with tf.variable_scope("image_filters1") as scope1: result1 = my_image_filter(image1) with tf.variable_scope(scope1, reuse = True) result2 = my_image_filter(image2)
通常情況下,tf.variable_scope 和 tf.name_scope 配合,能畫出非常漂亮的流程圖,但是他們兩個之間又有着細微的差別,那就是 name_scope 只能管住操作 Ops 的名字,而管不住變量 Variables 的名字,看下例:
with tf.variable_scope("foo"): with tf.name_scope("bar"): v = tf.get_variable("v", [1]) x = 1.0 + v assert v.name == "foo/v:0" assert x.op.name == "foo/bar/add"
參考資料:
1. https://www.tensorflow.org/how_tos/variable_scope/