『TensorFlow』使用集合collection控制variables


Variable

Tensorflow使用Variable類表達、更新、存儲模型參數。

  • Variable是在可變更的,具有保持性的內存句柄,存儲着Tensor
  • 在整個session運行之前,圖中的全部Variable必須被初始化
    • Variable的值在sess.run(init)之后就確定了
    • Tensor的值要在sess.run(x)之后才確定
  • 創建的Variable被添加到默認的collection

 

tf.GraphKeys中包含了所有默認集合的名稱,可以通過查看__dict__發現具體集合。

tf.GraphKeys.GLOBAL_VARIABLES:global_variables被收集在名為tf.GraphKeys.GLOBAL_VARIABLEScolletion中,包含了模型中的通用參數

tf.GraphKeys.TRAINABLE_VARIABLES:tf.Optimizer默認只優化tf.GraphKeys.TRAINABLE_VARIABLES中的變量。

函數 集合名 意義
tf.global_variables() GLOBAL_VARIABLES

存儲和讀取checkpoints時,使用其中所有變量

跨設備全局變量集合

tf.trainable_variables() TRAINABLE_VARIABLES

訓練時,更新其中所有變量

存儲需要訓練的模型參數的變量集合

tf.moving_average_variables() MOVING_AVERAGE_VARIABLES

ExponentialMovingAverage對象會生成此類變量

實用指數移動平均的變量集合

tf.local_variables() LOCAL_VARIABLES

global_variables()之外,需要用tf.init_local_variables()初始化

進程內本地變量集合

tf.model_variables() MODEL_VARIABLES

 Key to collect model variables defined by layers.

進程內存儲的模型參數的變量集合

  QUEUE_RUNNERS 並非存儲variables,存儲處理輸入的QueueRunner
  SUMMARIES 並非存儲variables,存儲日志生成相關張量

除了上表中的函數外(上表中最后兩個集合並非變量集合,為了方便一並放在這里),還可以使用tf.get_collection(集合名)獲取集合中的變量,不過這個函數更多與tf.get_collection(集合名)搭配使用,操作自建集合。

另,slim.get_model_variables()與tf.model_variables()功能近似。

 

Summary

Summary被收集在名為tf.GraphKeys.UMMARIEScolletion中,

  • Summary是對網絡中Tensor取值進行監測的一種Operation
  • 這些操作在圖中是“外圍”操作,不影響數據流本身
  • 調用tf.scalar_summary系列函數時,就會向默認的collection中添加一個Operation

 

自定義集合

除了默認的集合,我們也可以自己創造collection組織對象。網絡損失就是一類適宜對象。

tensorflow中的Loss提供了許多創建損失Tensor的方式。

x1 = tf.constant(1.0)
l1 = tf.nn.l2_loss(x1)

x2 = tf.constant([2.5, -0.3])
l2 = tf.nn.l2_loss(x2)

創建損失不會自動添加到集合中,需要手工指定一個collection

tf.add_to_collection("losses", l1)
tf.add_to_collection("losses", l2)

創建完成后,可以統一獲取所有損失,losses是個Tensor類型的list:

losses = tf.get_collection('losses')

一種常見操作把所有損失累加起來得到一個Tensor

loss_total = tf.add_n(losses)

 執行操作可以得到損失取值:

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
losses_val = sess.run(losses)
loss_total_val = sess.run(loss_total)

 實際上,如果使用TF-Slim包的losses系列函數創建損失,會自動添加到名為”losses”的collection中。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM