TensorFlow的變量管理:變量作用域機制


在深度學習中,你可能需要用到大量的變量集,而且這些變量集可能在多處都要用到。例如,訓練模型時,訓練參數如權重(weights)、偏置(biases)等已經定下來,要拿到驗證集去驗證,我們自然希望這些參數是同一組。以往寫簡單的程序,可能使用全局限量就可以了,但在深度學習中,這顯然是不行的,一方面不便管理,另外這樣一來代碼的封裝性受到極大影響。因此,TensorFlow提供了一種變量管理方法:變量作用域機制,以此解決上面出現的問題。

TensorFlow的變量作用域機制依賴於以下兩個方法,官方文檔中定義如下:

 

[plain]  view plain  copy
 
  1. tf.get_variable(name, shape, initializer): Creates or returns a variable with a given name.建立或返回一個給定名稱的變量  
  2. tf.variable_scope( scope_name): Manages namespaces for names passed to tf.get_variable(). 管理傳遞給tf.get_variable()的變量名組成的命名空間  

 

先說說tf.get_variable(),這個方法在建立新的變量時與tf.Variable()完全相同。它的特殊之處在於,他還會搜索是否有同名的變量。創建變量用法如下:

 

[plain]  view plain  copy
 
  1. with tf.variable_scope("foo"):  
  2.     with tf.variable_scope("bar"):  
  3.         v = tf.get_variable("v", [1])  
  4.         assert v.name == "foo/bar/v:0"  


而tf.variable_scope(scope_name),它會管理在名為scope_name的域(scope)下傳遞給tf.get_variable的所有變量名(組成了一個變量空間),根據規則確定這些變量是否進行復用。這個方法最重要的參數是reuse,有None,tf.AUTO_REUSE與True三個選項。具體用法如下:

 

 

  1. reuse的默認選項是None,此時會繼承父scope的reuse標志。
  2. 自動復用(設置reuse為tf.AUTO_REUSE),如果變量存在則復用,不存在則創建。這是最安全的用法,在使用新推出的EagerMode時reuse將被強制為tf.AUTO_REUSE選項。用法如下:
    [plain]  view plain  copy
     
    1. def foo():  
    2.   with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):  
    3.     v = tf.get_variable("v", [1])  
    4.   return v  
    5.   
    6. v1 = foo()  # Creates v.  
    7. v2 = foo()  # Gets the same, existing v.  
    8. assert v1 == v2  
  3. 復用(設置reuse=True):
    [plain]  view plain  copy
     
    1. with tf.variable_scope("foo"):  
    2.   v = tf.get_variable("v", [1])  
    3. with tf.variable_scope("foo", reuse=True):  
    4.   v1 = tf.get_variable("v", [1])  
    5. assert v1 == v  
  4. 捕獲某一域並設置復用(scope.reuse_variables()):
    [plain]  view plain  copy
     
    1. with tf.variable_scope("foo") as scope:  
    2.   v = tf.get_variable("v", [1])  
    3.   scope.reuse_variables()  
    4.   v1 = tf.get_variable("v", [1])  
    5. assert v1 == v  

    1)非復用的scope下再次定義已存在的變量;或2)定義了復用但無法找到已定義的變量,TensorFlow都會拋出錯誤,具體如下:
[plain]  view plain  copy
 
  1. with tf.variable_scope("foo"):  
  2.     v = tf.get_variable("v", [1])  
  3.     v1 = tf.get_variable("v", [1])  
  4.     #  Raises ValueError("... v already exists ...").  
  5.   
  6.   
  7. with tf.variable_scope("foo", reuse=True):  
  8.     v = tf.get_variable("v", [1])  
  9.     #  Raises ValueError("... v does not exists ...").  
 
轉自: https://blog.csdn.net/zbgjhy88/article/details/78960388


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM