tensorflow學習筆記-Graph、GraphDef、MetaGraphDef【轉】


原文 - Tensorflow框架實現中的“三”種圖 - 知乎
原文 - https://www.aiuai.cn/aifarm701.html

圖(Graph) 是 TensorFlow 用於表達計算任務的一個核心概念.

從前端(python) 描述神經網絡的結構,到后端在多機和分布式系統上部署,到底層 Device(CPU、GPU、TPU)上運行,都是基於圖來完成.

然而在實際使用過程中遇到了三對API,

[1] - tf.train.Saver()/saver.restore()

[2] - export_meta_graph/Import_meta_graph

[3] - tf.train.write_graph()/tf.Import_graph_def()

它們都是用於對圖的保存和恢復.

同一個計算框架,為什么需要三對不同的API呢?他們保存/恢復的圖在使用時又有什么區別呢?

初學的時候,常常鬧不清楚他們的區別,以至常常寫出了錯誤的程序,經過一番研究,本文中對Tensorflow中圍繞Graph的核心概念進行了總結.

1. Graph

首先介紹一下關於 TensorFlow 中 Graph 和它的序列化表示 Graph_def.

在 TensorFlow 官方文檔中,Graph 被定義為 “一些 Operation 和 Tensor 的集合”.

例如表達如下的一個計算的 python代碼,

import tensorflow as tf

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.placeholder(tf.float32)
d = a*b+c
e = d*2

就會生成相應的一張圖,在Tensorboard中看到的圖大概如圖:

其中,每一個圓圈表示一個 Operation(輸入處為Placeholder),橢圓到橢圓的邊為Tensor,箭頭的指向表示了這張圖 Operation 輸入輸出 Tensor 的傳遞關系.

在真實的 TensorFlow 運行中,Python 構建的“圖Graph” 並不是啟動一個 Session 之后始終不變的. 因為 TensorFlow 在運行時,真實的計算會被分配到多CPUs,或 GPUs,或 ARM 等,以進行高性能/能效的計算. 單純使用 Python 肯定是無法有效完成的.

實際上,TensorFlow 是首先將 python 代碼所描繪的圖轉換(即“序列化”)成 Protocol Buffer,再通過 C/C++/CUDA 運行 Protocol Buffer 所定義的圖. (Protocol Buffer 可參考:https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/).

2. GraphDef

從 python Graph中序列化出來的圖就叫做 GraphDef (這是一種不嚴格的說法,先這樣進行理解).

而 GraphDef 又是由許多叫做 NodeDef 的 Protocol Buffer 組成. 在概念上 NodeDef 與(Python Graph 中的) Operation 相對應.

如下就是 GraphDef 的 ProtoBuf,由許多node 組成的圖表示. 這是與上文 Python 圖對應的 GraphDef:

node {
  name: "Placeholder"    # 注:這是一個叫做 "Placeholder" 的node
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "Placeholder_1" # 注:這是一個叫做 "Placeholder_1" 的node
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "mul"          # 注:一個 Mul(乘法)操作
  op: "Mul"
  input: "Placeholder" # 使用上面的node(即Placeholder和Placeholder_1)
  input: "Placeholder_1" # 作為這個Node的輸入
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}

以上三個 NodeDef 定義了兩個 Placeholde r和一個Multiply.

Placeholder 通過 attr(attribute的縮寫)來定義數據類型和 Tensor 的形狀.

Multiply 通過 input 屬性定義了兩個 placeholder 作為其輸入.

無論是 Placeholder 還是 Multiply 都沒有關於輸出(output)的信息.

其實 Tensorflow 中都是通過 Input 來定義 Node 之間的連接信息.

那么既然 tf.Operation 的序列化 ProtoBuf 是 NodeDef,那么 tf.Variable 呢?在這個 GraphDef 中只有網絡的連接信息,卻沒有任何 Variables呀?

沒錯,Graphdef 中不保存任何 Variable 的信息,所以如果從 graph_def 來構建圖並恢復訓練的話,是不能成功的.

如,

with tf.Graph().as_default() as graph:
  tf.import_graph_def("graph_def_path")
  saver= tf.train.Saver()
  with tf.Session() as sess:
    tf.trainable_variables()

其中 tf.trainable_variables() 只會返回一個空的list.tf.train.Saver() 也會報告 no variables to save.

然而,在實際線上 inference 中,通常就是使用 GraphDef. 但,GraphDef 中連 Variable都沒有,怎么存儲 weight 呢?

原來, GraphDef 雖然不能保存 Variable,但可以保存 Constant. 通過 tf.constant 將 weight 直接存儲在 NodeDef 里,tensorflow 1.3.0 版本也提供了一套叫做 freeze_graph 的工具來自動的將圖中的 Variable 替換成 constant 存儲在 GraphDef 里面,並將該圖導出為 Proto.

https://www.tensorflow.org/extend/tool_developers/https://www.tensorflow.org/mobile/prepare_models

tf.train.write_graph()/tf.Import_graph_def() 就是用來進行 GraphDef 讀寫的API. 那么,我們怎么才能從序列化的圖中,得到 Variables呢?這就要學習下一個重要概念,MetaGraph.

3. MetaGraph

Meta graph 的官方解釋是:一個 Meta Graph 由一個計算圖和其相關的元數據構成, 其包含了用於繼續訓練,實施評估和(在已訓練好的的圖上)做前向推斷的信息.

A MetaGraph consists of both a computational graph and its associated metadata.
A MetaGraph contains the information required to continue training, perform evaluation, or run inference on a previously trained graph.

From https://www.tensorflow.org/versions/r1.1/programmers_guide/

這一段看的雲里霧里,不過這篇文章(https://www.tensorflow.org/versions/r1.1/programmers_guide/meta_graph)進一步解釋說,Meta Graph在具體實現上就是一個 MetaGraphDef (同樣是由 Protocol Buffer來定義的). 其包含了四種主要的信息,根據Tensorflow官網,這四種 Protobuf 分別是:

[1] - MetaInfoDef,存一些元信息(比如版本和其他用戶信息)
[2] - GraphDef, MetaGraph 的核心內容之一
[3] - SaverDef,圖的Saver信息(比如最多同時保存的check-point數量,需保存的Tensor名字等,但並不保存Tensor中的實際內容)
[4] - CollectionDef,任何需要特殊注意的 Python 對象,需要特殊的標注以方便import_meta_graph 后取回(如 train_op, prediction 等等)

在以上四種 ProtoBuf 里面,[1] 和 [3] 都比較容易理解,[2] 剛剛總結過. 這里特別要講一下 Collection(CollectionDef是對應的ProtoBuf).

TensorFlow 中並沒有一個官方的定義說 collection 是什么. 簡單的理解,它就是為了方別用戶對圖中的操作和變量進行管理,而創建的一個概念. 它可以說是一種“集合”,通過一個 key (string類型) 來對一組 Python 對象進行命名的集合. 這個 key 既可以是TensorFlow 在內部定義的一些 key,也可以是用戶自己定義的名字(string).

TensorFlow 內部定義了許多標准 Key,全部定義在了 tf.GraphKeys 這個類中. 其中有一些常用的,tf.GraphKeys.TRAINABLE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES 等等. tf.trainable_variables() 與 tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 是等價的;tf.global_variables() 與 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 是等價的.

集合類型 集合內容 使用環境
tf.GraphKeys.VARIABLES 神經網絡參數
tf.GraphKeys.TRAINABLE_VARIABLES 模型訓練,生產模型可視化內容
tf.GraphKeys.SUMMARIES 日志生成相關張量 計算可視化
tf.GraphKeys.QUEUE_RUNNER 處理輸入的QueueRunner 輸入處理
tf.MOVING_AVERAGE_BARIABLES 所有計算了滑動平均值的變量 計算變量滑動平均值

對於用戶定義的 key,舉一個例子, 例如:

pred = model_network(X)
loss=tf.reduce_mean(…, pred, …)
train_op=tf.train.AdamOptimizer(lr).minimize(loss)

這樣一段 Tensorflow程序,用戶希望特別關注 pred, loss, train_op 這幾個操作,那么就可以使用如下代碼,將這幾個變量加入到 collection 中去. (假設我們將其命名為 “training_collection”)

tf.add_to_collection("training_collection", pred)
tf.add_to_collection("training_collection", loss)
tf.add_to_collection("training_collection", train_op)

並且可以通過 Train_collect = tf.get_collection(“training_collection”) 得到一個python list,其中的內容就是pred, loss, train_op 的 Tensor. 這通常是為了在一個新的 session 中打開這張圖時,方便我們獲取想要的操作. 比如我們可以直接通過 get_collection() 得到 train_op,然后通過 sess.run(train_op) 來開啟一段訓練,而無需重新構建 loss 和optimizer.

通過 export_meta_graph 保存圖,並且通過 add_to_collection 將 train_op 加入到 collection 中:

with tf.Session() as sess:
  pred = model_network(X)
  loss=tf.reduce_mean(…,pred, …)
  train_op=tf.train.AdamOptimizer(lr).minimize(loss)
  tf.add_to_collection("training_collection", train_op)
  Meta_graph_def = 
      tf.train.export_meta_graph(tf.get_default_graph(), 'my_graph.meta')

通過 import_meta_graph 將圖恢復(同時初始化為本 Session的 default 圖),並且通過 get_collection 重新獲得 train_op,以及通過 train_op 來開始一段訓練(sess.run() ).

with tf.Session() as new_sess:
  tf.train.import_meta_graph('my_graph.meta')
  train_op = tf.get_collection("training_collection")[0]
  new_sess.run(train_op)

更多的代碼例子可以在這篇文檔(https://www.tensorflow.org/api_guides/python/meta_graph)中的 Import a MetaGraph 章節中看到.

那么,從 Meta Graph 中恢復構建的圖可以被訓練嗎?是可以的. TensorFlow 的官方文檔 https://www.tensorflow.org/api_guides/python/meta_graph 說明了使用方法. 這里要特殊的說明一下,Meta Graph 中雖然包含 Variable 的信息,卻沒有 Variable 的實際值. 所以, 從Meta Graph 中恢復的圖,其訓練是從隨機初始化的值開始的. 訓練中 Variable的實際值都保存在 checkpoint 中,如果要從之前訓練的狀態繼續恢復訓練,就要從checkpoint 中 restore. 進一步讀一下 Export Meta Graph 的代碼,可以看到,事實上variables 並沒有被 export 到 meta_graph 中.

https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/training/saver.py (1872行)
https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/framework/meta_graph.py (829,845行)

export_meta_graph/Import_meta_graph 就是用來進行 Meta Graph 讀寫的API.

tf.train.saver.save() 在保存checkpoint的同時也會保存Meta Graph. 但是在恢復圖時,tf.train.saver.restore() 只恢復 Variable,如果要從MetaGraph恢復圖,需要使用 import_meta_graph. 這是其實為了方便用戶,有時我們不需要從MetaGraph恢復的圖,而是需要在 python 中構建神經網絡圖,並恢復對應的 Variable.

4. Checkpoint

Checkpoint 里全面保存了訓練某時間截面的信息,包括參數,超參數,梯度等等. tf.train.Saver()/saver.restore() 則能夠完完整整保存和恢復神經網絡的訓練.

Checkpoint 分為兩個文件保存Variable的二進制信息. ckpt 文件保存了Variable的二進制信息,index 文件用於保存 ckpt 文件中對應 Variable 的偏移量信息.

5. 總結

TensorFlow 三種 API 所保存和恢復的圖是不一樣的.

這三種圖是從 TensorFlow 框架設計的角度出發而定義的.

但是從用戶的角度來看,TensorFlow 文檔的寫作難免有些雲里霧里,弄不清他們的區別.需要讀一讀Tensorflow的代碼,做一些實驗來進行辨析.

簡而言之,TensorFlow 在前端 Python 中構建圖,並且通過將該圖序列化到 ProtoBuf GraphDef,以方便在后端運行. 在這個過程中,圖的保存、恢復和運行都通過 ProtoBuf 來實現. GraphDef,MetaGraph,以及Variable,Collection 和 Saver 等都有對應的 ProtoBuf 定義. ProtoBuf 的定義也決定了用戶能對圖進行的操作. 例如用戶只能找到 Node的前一個Node,卻無法得知自己的輸出會由哪個Node接收.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM