TensorFlow中的variable_scope


學了tf比較長一段時間了,一直沒有搞懂tf中的variable_scope的用法。感覺有些知識點很零碎,這次看了一本書(質量比想象中的要好很多啊),整體的回顧一下tf。

1. tf變量管理


tf提供了通過變量名稱來創建或者獲取一個變量的機制。通過這個機制,在不同的函數中可以直接通過變量的名字來使用變量,而不需要將變量通過參數的形式到處傳遞(確實是一個痛點)。tf中通過變量名稱獲取變量的機制主要是通過tf.get_variabletf.variable_scope函數實現的。

除了tf.Variable函數,tf還提供了tf.get_variable函數來創建或者獲取變量。當tf.get_variable用於創建變量時,它和tf.Variable的功能是基本等價的:

# 下面這兩個定義是等價的
v = tf.get_variable("v", shape=[1], initializer=tf.constant_initializer(1.0))
v = tf.Variable(tf.constant(1.0, shape=[1]), name="v")

兩者最大的區別在於指定變量名稱的參數。對於tf.Variable函數,變量名稱是一個可選的參數。但是對於tf.get_variable函數,變量名稱是一個必填的參數。

2. variable_scope用法


  • 下面給出了一段代碼說明如何通過tf.variable_scope函數來控制tf.get_variable函數獲取己經創建過的變量:
# 在名字為foo的命名空間內創建名字為v的變量
with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1], initializer=tf.constant_initializer(1.0))
    
# 因為在命名空間foo中已經存在名字為v的變量,所以以下代碼會報錯:
# with tf.variable_scope("foo"):
#     v = tf.get_variable("v")
    
# 在生成上下文管理器時,將參數reuse設置為True,這樣get_variable可以直接獲取已經聲明的變量
with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v")
    print(v==v1) # 輸出True
    # 下面一句會報錯:
    # u = tf.get_variable("u", [1], initializer=tf.constant_initializer(2.0))

# reuse為True時,只能獲取已經創建的變量,除非改成tf.ATUO_REUSE
with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    u = tf.get_variable("u", [1], initializer=tf.constant_initializer([2.0]))
  • tf.variable_scope函數是可以嵌套的:
with tf.variable_scope("root"):
    print(tf.get_variable_scope().reuse) # 輸出False
    
    with tf.variable_scope("foo", reuse=True):
        print(tf.get_variable_scope().reuse) # 輸出True
        
        with tf.variable_scope("bar"): # 輸出True(不指定的話會和上一層保持一致)
            print(tf.get_variable_scope().reuse)
            
    print(tf.get_variable_scope().reuse) # 輸出False(回到了最外層)
  • 使用tf.variable_scope管理變量名稱:
v1 = tf.get_variable("v", [1])
print(v1.name) # 輸出v:0

with tf.variable_scope("foo"):
    v2 = tf.get_variable("v", [1])
    print(v2.name) # 輸出foo/v:0

with tf.variable_scope("foo"):
    with tf.variable_scope("bar"):
        v3 = tf.get_variable("v", [1])
        print(v3.name) # 輸出foo/bar/v:0

with tf.variable_scope("foo", reuse=True):
    v5 = tf.get_variable("bar/v")
    print(v5==v3) # 輸出True
    v6 = tf.get_variable("v")
    print(v6==v2) # 輸出True

with tf.variable_scope("rick"):
    v = tf.get_variable("v", [1])
    print(v.name) # 輸出rick/v:0
    print(v==v1) # 輸出False

總結一下:

  1. 如果已經創建過一個name為"foo"的variable_scope,再次使用with tf.variable_scope("foo")時,不能夠用get_variable獲取到"foo"中的同名變量。
  2. 對於1,如果是with tf.variable_scope("foo", reuse=True)的話,可以獲取到同名變量,但是無法創建新的變量,除非將reuse設置為tf.AUTO_REUSE
  3. 如果variable_scope的name和之前都不同的話,且reuse=False,那么可以用任意名稱創建變量(如"foo"和"rick"的對比)。
  4. variable_scope就像一個維護變量的空間,reuse在variable_scope創建時一般是False,變量使用時再指定為True。

3. 具體實例


以一個mnist的小程序為例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

INPUT_NODE = 784
OUTPUT_NODE = 10

LAYER1_NODE = 500
BATCH_SIZE = 100

LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 10000
MOVING_AVERAGE_DECAY = 0.99

# 使用variable_scope管理變量
def inference(input_tensor, reuse=False):
    with tf.variable_scope("layer1", reuse=reuse):
        weights = tf.get_variable("weights", [INPUT_NODE, LAYER1_NODE], initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.1))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope("layer2", reuse=reuse):
        weights = tf.get_variable("weights", [LAYER1_NODE, OUTPUT_NODE], initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.1))
        layer2 = tf.matmul(layer1, weights) + biases
    return layer2

def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name="x-input")
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name="y-input")

    y = inference(x)

    global_step = tf.Variable(0, trainable=False)
    
    # 計算損失函數
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # 注意這一段,weights1和weights2都在inference中創建,但是L2正則化需要用到這個變量。
    # 使用variable_scope就能很方便的獲取到這兩個變量:
    with tf.variable_scope("", reuse=True):
        weights1 = tf.get_variable("layer1/weights")
        weights2 = tf.get_variable("layer2/weights")
    
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)
    loss = cross_entropy_mean + regularization

    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, 
                                                global_step, 
                                                mnist.train.num_examples/BATCH_SIZE,
                                                LEARNING_RATE_DECAY)

    train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
        
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}
        
        for i in range(TRAINING_STEPS):
            if (i+1) % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print("After %d training steps, validation accuracy using average model is %f" % (i+1, validate_acc))
                
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x:xs, y_:ys})
            
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, validation accuracy using average model is %f" % (TRAINING_STEPS, test_acc))
        

def main(argv=None):
    mnist = input_data.read_data_sets("./mnist", one_hot=True)
    train(mnist)
    
if __name__ == "__main__":
    main()

參考資料

  1. TensorFlow:實戰Google深度學習框架(第2版)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM