Variable
Tensorflow使用Variable
類表達、更新、存儲模型參數。
Variable
是在可變更的,具有保持性的內存句柄,存儲着Tensor
- 在整個
session
運行之前,圖中的全部Variable
必須被初始化
Variable
的值在sess.run(init)之后就確定了Tensor
的值要在sess.run(x)之后才確定
- 創建的
Variable
被添加到默認的collection
中
tf.GraphKeys
中包含了所有默認集合的名稱,可以通過查看__dict__發現具體集合。
被收集在名為tf.GraphKeys.
GLOBAL_VARIABLES:global_variablestf.GraphKeys.GLOBAL_VARIABLES
的colletion
中,包含了模型中的通用參數
tf.GraphKeys.TRAINABLE_VARIABLES:
tf.Optimizer默認
只優化tf.GraphKeys.TRAINABLE_VARIABLES
中的變量。
函數 | 集合名 | 意義 |
---|---|---|
tf.global_variables() | GLOBAL_VARIABLES | 存儲和讀取checkpoints時,使用其中所有變量 跨設備全局變量集合 |
tf.trainable_variables() | TRAINABLE_VARIABLES | 訓練時,更新其中所有變量 存儲需要訓練的模型參數的變量集合 |
tf.moving_average_variables() | MOVING_AVERAGE_VARIABLES |
實用指數移動平均的變量集合 |
tf.local_variables() | LOCAL_VARIABLES | 在 進程內本地變量集合 |
tf.model_variables() | MODEL_VARIABLES | Key to collect model variables defined by layers. 進程內存儲的模型參數的變量集合 |
QUEUE_RUNNERS | 並非存儲variables,存儲處理輸入的QueueRunner | |
SUMMARIES | 並非存儲variables,存儲日志生成相關張量 |
除了上表中的函數外(上表中最后兩個集合並非變量集合,為了方便一並放在這里),還可以使用tf.get_collection(集合名)獲取集合中的變量,不過這個函數更多與tf.get_collection(集合名)搭配使用,操作自建集合。
另,slim.get_model_variables()與tf.model_variables()功能近似。
Summary
Summary
被收集在名為tf.GraphKeys.
UMMARIES
的colletion
中,
Summary
是對網絡中Tensor
取值進行監測的一種Operation
- 這些操作在圖中是“外圍”操作,不影響數據流本身
- 調用tf.scalar_summary系列函數時,就會向默認的
collection
中添加一個Operation
自定義集合
除了默認的集合,我們也可以自己創造collection
組織對象。網絡損失就是一類適宜對象。
tensorflow中的Loss提供了許多創建損失Tensor
的方式。
x1 = tf.constant(1.0) l1 = tf.nn.l2_loss(x1) x2 = tf.constant([2.5, -0.3]) l2 = tf.nn.l2_loss(x2)
創建損失不會自動添加到集合中,需要手工指定一個collection
:
tf.add_to_collection("losses", l1) tf.add_to_collection("losses", l2)
創建完成后,可以統一獲取所有損失,losses
是個Tensor
類型的list:
losses = tf.get_collection('losses')
一種常見操作把所有損失累加起來得到一個Tensor
:
loss_total = tf.add_n(losses)
執行操作可以得到損失取值:
sess = tf.Session() init = tf.global_variables_initializer() sess.run(init) losses_val = sess.run(losses) loss_total_val = sess.run(loss_total)
實際上,如果使用TF-Slim包的losses系列函數創建損失,會自動添加到名為”losses”的collection
中。