TensorFlow創建變量


1 使用tf.Variable函數創建變量

tf.Variable(initial_value=None,trainable=True,collections=None,validate_shape=True,caching_device=None,name=None,variable_def=None,dtype=None,expected_shape=None,import_scope=None)

函數功能:

  創建一個新的變量,變量的值是initial_value,創建的變量會被添加到[GraphKeys.GLOBAL_VARIABLES]默認的計算圖列表中,如果trainable被設置為True,這個變量還會被添加到GraphKeys.TRAINABLE_VARIABLES計算圖的集合中。

參數:

  • initial_value:默認值是None,張量或者是一個python對象可以轉成張量,這個initial_value是初始化變量的值。它必須有一個特殊的shape,除非validate_shape設置為False。
  • trainable:默認的是True,變量還會被添加到GraphKeys.TRAINABLE_VARIABLES計算圖集合中。
  • collections:變量會被添加到這個集合中,默認的集合是[GraphKeys.GLOBAL_VARIABLES]。
  • validate_shape:如果是False,允許這個變量被初始化一個不知道shape。默認的是True,這個initial_value的shape必須是知道的。
  • name:變量的名字。
  • dypte:變量的類型,小數的默認是float32,整數默認是int32。

2 使用tf.get_variable函數創建變量

tf.get_variable(name,shape=None,dtype=None,initializer=None,regularizer=None,trainable=True,collections=None,caching_device=None,partitioner=None,validate_shape=True,use_resource=None,custom_getter=None)
函數功能:

  根據變量的名稱來獲取變量或者創建變量。

參數:

  • name:變量的名稱(必選)。
  • shape:變量的shape。
  • dtype:變量的數據類型。
  • initializer:變量的初始化值。

2.1 根據變量的名稱創建變量

b = tf.get_variable(name="b", initializer=[1., 2., 3.]) init = tf.initialize_all_variables() sess = tf.Session() sess.run(init) print(b.eval(session=sess)) #[ 1. 2. 3.]
print(b.dtype) #<dtype: 'float32_ref'>

使用tf.get_variable創建變量的時候,如果不指定name,會報TypeError: get_variable() missing 1 required positional argument: 'name'

2.2 根據變量的名稱獲取變量

with tf.variable_scope("f"): #初始化一個變量名稱為c的變量
        c = tf.get_variable(name="c",shape=[3],initializer=tf.constant_initializer([1,2,3])) with tf.variable_scope("f",reuse=True): d = tf.get_variable(name="c",shape=[3]) sess = tf.Session() init = tf.initialize_all_variables() sess.run(init) print(d.eval(session=sess)) #[ 1. 2. 3.]
        print(c.eval(session=sess)) #[ 1. 2. 3.]
        print(d == c) #True

  在使用tf.get_variable()根據變量的名稱來獲取已經生成變量的時候,需要通過tf.variable_scope函數來生成一個上下文管理器,並明確指定在這個上下文管理器中。獲取變量值的時候,需要將上下文管理器中的reuse設置為True,才能直接獲取已經聲明的變量,如果不設置reuse會報錯。需要注意的是,如果變量名在上下文管理器中已經存在,在獲取的時候,如果不將reuse設置為True則會報錯。同理,如果上下文管理器中不存在變量名,在使用reuse=True獲取變量值的時候,也會報錯。
補充:

(1)tf.variable_scope的嵌套

with tf.variable_scope("a"):#默認是False
  #查看上下文管理器中的reuse的值
  print(tf.get_variable_scope().reuse) #False
  with tf.variable_scope("b",reuse=True):     print(tf.get_variable_scope().reuse) #True
       #如果reuse是默認的則保持和上一層的reuse值一樣
       with tf.variable_scope("c"):   print(tf.get_variable_scope().reuse) #True
    print(tf.get_variable_scope().reuse) #False

(2)上下文管理器與變量名

#沒有上下文管理器
a = tf.get_variable(name="a",shape=[2],initializer=tf.constant_initializer([1,2])) print(a.name) #a:0,a就是變量名

#聲明上下文管理器 with tf.variable_scope("f"):   b = tf.get_variable(name="b",shape=[2],initializer=tf.constant_initializer([1,2]))   print(b.name) #f/b:0,f代表的是上下文管理器的名稱,b代表的是變量的名稱   #嵌套上下文管理器   with tf.variable_scope("g"):     c = tf.get_variable(name="c",shape=[2],initializer=tf.constant_initializer([1,2]))     print(c.name)#f/g/c:0

(3)通過上下文管理器和變量名來獲取變量

#通過帶上下文管理器名稱和變量名來獲取變量
with tf.variable_scope("",reuse=True):   d = tf.get_variable(name="f/b")   print(d == b)  #True
  e = tf.get_variable(name="f/g/c")   print(e == c)  #True

轉:修煉之路的博客(侵刪)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM