tensorflow中訓練后的模型是一個pb文件,proto 文件如下:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
通過解析pb文件即可以拿到訓練后的的權重信息。
with open(output_graph_path,"rb") as f: output_graph.ParseFromString(f.read())
graph是有node節點組成,遍歷所有的node 節點可以獲取到 訓練的權重信息以及shape大小:
for node in output_graph.node: print 'name:{}'.format(node.name) print 'shape:{},dtype;{}'.format(node.attr['value'].tensor.tensor_shape,node.attr['value'].tensor.dtype) if node.attr['value'].tensor.dtype != 1: continue print tensor_util.MakeNdarray(node.attr['value'].tensor)
graph是一個有向圖,由節點和有向邊組成,假設有如下計算表達式:t1=MatMul(input, W1)。
圖計算表達式包含三個節點,兩條邊,描述為文字形式如下:
TF 調用protobuf 解析方法,將graph 的字符串描述解析並生成grapdef實例,下一節查看graphdef的輸入和輸出,並嘗試將模型轉換為caffe模型
代碼在git地址:https://github.com/wudafucode/machine_learning/blob/master/showpb.py