tensorflow學習筆記——模型持久化的原理,將CKPT轉為pb文件,使用pb模型預測


  由題目就可以看出,本節內容分為三部分,第一部分就是如何將訓練好的模型持久化,並學習模型持久化的原理,第二部分就是如何將CKPT轉化為pb文件,第三部分就是如何使用pb模型進行預測。

  這里新增一個h5轉tflite,h5轉pb,pb轉 tflite的文件代碼,代碼直接展示,不寫什么了,我感覺其實也沒有必要寫什么了,該說的都說了,只不過h5模型是Keras中訓練的。

一,模型持久化

  為了讓訓練得到的模型保存下來方便下次直接調用,我們需要將訓練得到的神經網絡模型持久化。下面學習通過TensorFlow程序來持久化一個訓練好的模型,並從持久化之后的模型文件中還原被保存的模型,然后學習TensorFlow持久化的工作原理和持久化之后文件中的數據格式。

1,持久化代碼實現

  TensorFlow提供了一個非常簡單的API來保存和還原一個神經網絡模型。這個API就是 tf.train.Saver 類。使用 tf.train.saver()  保存模型時會產生多個文件,會把計算圖的結構和圖上參數取值分成了不同的文件存儲。這種方式是在TensorFlow中是最常用的保存方式。

  下面代碼給出了保存TensorFlow計算圖的方法:

#_*_coding:utf-8_*_
import tensorflow as tf
import os

# 聲明兩個變量並計算他們的和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
# 聲明 tf.train.Saver類用於保存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # 將模型保存到model.ckpt文件中
    model_path = 'model/model.ckpt'
    saver.save(sess, model_path)

  上面的代碼實現了持久化一個簡單的TensorFlow模型的功能。在這段代碼中,通過saver.save 函數將TensorFlow模型保存到了 model/model.path 文件中。TensorFlow模型一般會保存在后綴為 .ckpt 的文件中,雖然上面的程序只指定了一個文件路徑,但是這個文件目錄下面會出現三個文件。這是因為TensorFlow會將計算圖的結構和圖上參數取值分開保存。

  運行上面代碼,我們查看model文件里面的文件如下:

   下面解釋一下文件分別是干什么的:

  • checkpoint文件是檢查點文件,文件保存了一個目錄下所有模型文件列表。
  • model.ckpt.data文件保存了TensorFlow程序中每一個變量的取值
  • model.ckpt.index文件則保存了TensorFlow程序中變量的索引
  • model.ckpt.meta文件則保存了TensorFlow計算圖的結構(可以簡單理解為神經網絡的網絡結構),該文件可以被 tf.train.import_meta_graph 加載到當前默認的圖來使用。

      下面代碼給出加載這個模型的方法:

#_*_coding:utf-8_*_
import tensorflow as tf

#使用和保存模型代碼中一樣的方式來聲明變量
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    # 加載已經保存的模型,並通過已經保存的模型中的變量的值來計算加法
    model_path = 'model/model.ckpt'
    saver.restore(sess, model_path)
    print(sess.run(result))

# 結果如下:[3.]

  這段加載模型的代碼基本上和保存模型的代碼是一樣的。在加載模型的程序中也是先定義了TensorFlow計算圖上所有運算,並聲明了一個 tf.train.Saver類。兩段代碼唯一不同的是,在加載模型的代碼中沒有運行變量的初始化過程,而是將變量的值通過已經保存的模型加載出來。如果不希望重復定義圖上的運算,也可以直接加載已經持久化的圖,以下代碼給出一個樣例:

import tensorflow as tf

# 直接加載持久化的圖
model_path = 'model/model.ckpt'
model_path1 = 'model/model.ckpt.meta'
saver = tf.train.import_meta_graph(model_path1)

with tf.Session() as sess:
    saver.restore(sess, model_path)
    # 通過張量的的名稱來獲取張量
    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))

# 結果如下:[3.]

  其上面給出的程序中,默認保存和加載了TensorFlow計算圖上定義的所有變量。但是有時可能只需要保存或者加載部分變量。比如,可能有一個之前訓練好的五層神經網絡模型,現在想嘗試一個六層神經網絡,那么可以將前面五層神經網絡中的參數直接加載到新的模型,而僅僅將最后一層神經網絡重新訓練。

  為了保存或者加載部分變量,在聲明 tf.train.Saver 類時可以提供一個列表來指定需要保存或者加載的變量。比如在加載模型的代碼中使用 saver = tf.train.Saver([v1]) 命令來構建 tf.train.Saver 類,那么只有變量 v1 會被加載進來。如果運行修改后只加載了 v1 的代碼會得到變量未初始化的錯誤:

tensorflow.python.framework.errors.FailedPreconditionError:Attempting to 
use uninitialized value v2

  因為 v2 沒有被加載,所以v2在運行初始化之前是沒有值的。除了可以選取需要被加載的變量,tf.train.Saver 類也支持在保存或者加載時給變量重命名。

  下面給出一個簡單的樣例程序說明變量重命名是如何被使用的。

import tensorflow as tf

# 這里聲明的變量名稱和已經保存的模型中變量的的名稱不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='other-v2')

# 如果直接使用 tf.train.Saver() 來加載模型會報變量找不到的錯誤,下面顯示了報錯信息
# tensorflow.python.framework.errors.FailedPreconditionError:Tensor name 'other-v2'
# not found in checkpoint file model/model.ckpt

# 使用一個字典來重命名變量就可以加載原來的模型了
# 這個字典指定了原來名稱為 v1 的變量現在加載到變量 v1中(名稱為 other-v1)
# 名稱為v2 的變量加載到變量 v2中(名稱為 other-v2)
saver = tf.train.Saver({'v1': v1, 'v2': v2})

  在這個程序中,對變量 v1 和 v2 的名稱進行了修改。如果直接通過 tf.train.Saver 默認的構造函數來加載保存的模型,那么程序會報變量找不到的錯誤,因為保存時候的變量名稱和加載時變量的名稱不一致。為了解決這個問題,Tensorflow 可以通過字典(dictionary)將模型保存時的變量名和需要加載的變量聯系起來。這樣做的主要目的之一就是方便使用變量的滑動平均值。在之前介紹了使用變量的滑動平均值可以讓神經網絡模型更加健壯(robust)。在TensorFlow中,每一個變量的滑動平均值是通過影子變量維護的,所以要獲取變量的滑動平均值實際上就是獲取這個影子變量的取值。如果在加載模型時將影子變量映射到變量本身,那么在使用訓練好的模型時就不需要再調用函數來獲取變量的滑動平均值了。這樣就大大方便了滑動平均模型的時域。下面代碼給出了一個保存滑動平均模型的樣例:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name='v')
# 在沒有申明滑動平均模型時只有一個變量 v,所以下面語句只會輸出 v:0
for variables in tf.global_variables():
    print(variables.name)

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
# 在申明滑動平均模型之后,TensorFlow會自動生成一個影子變量 v/ExponentialMovingAverage
# 於是下面的語句會輸出 v:0 和 v/ExponentialMovingAverage:0
for variables in tf.global_variables():
    print(variables.name)

saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 保存時,TensorFlow會將v:0 和 v/ExponentialMovingAverage:0 兩個變量都保存下來
    saver.save(sess, 'model/modeltest.ckpt')
    print(sess.run([v, ema.average(v)]))
    # 輸出結果 [10.0, 0.099999905]

  下面代碼給出了如何通過變量重命名直接讀取變量的滑動平均值。從下面程序的輸出可以看出,讀取的變量 v 的值實際上是上面代碼中變量 v 的滑動平均值。通過這個方法,就可以使用完全一樣的代碼來計算滑動平均模型前向傳播的結果:

v = tf.Variable(0, dtype=tf.float32, name='v')
# 通過變量重命名將原來變量v的滑動平均值直接賦值給 V
saver = tf.train.Saver({'v/ExponentialMovingAverage': v})
with tf.Session() as sess:
    saver.restore(sess, 'model/modeltest.ckpt')
    print(sess.run(v))
    # 輸出 0.099999905  這個值就是原來模型中變量 v 的滑動平均值

  為了方便加載時重命名滑動平均變量,tf.train.ExponentialMovingAverage 類提供了 variables_tp_restore 函數來生成 tf.train.Saver類所需要的變量重命名字典,一下代碼給出了 variables_to_restore 函數的使用樣例:

v = tf.Variable(0, dtype=tf.float32, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)

# 通過使用 variables_to_restore 函數可以直接生成上面代碼中提供的字典
# {'v/ExponentialMovingAverage': v}
# 下面代碼會輸出 {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
print(ema.variables_to_restore())

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, 'model/modeltest.ckpt')
    print(sess.run(v))
    # 輸出 0.099999905  即原來模型中變量 v 的滑動平均值

  使用 tf.train.Saver 會保存進行TensorFlow程序所需要的全部信息,然后有時並不需要某些信息。比如在測試或者離線預測時,只需要知道如何從神經網絡的輸出層經過前向傳播計算得到輸出層即可,而不需要類似於變量初始化,模型保存等輔助接點的信息。而且,將變量取值和計算圖結構分成不同的文件存儲有時候也不方便,於是TensorFlow提供了 convert_variables_to_constants 函數,通過這個函數可以將計算圖中的變量及其取值通過常量的方式保存,這樣整個TensorFlow計算圖可以統一存放在一個文件中,該方法可以固化模型結構,而且保存的模型可以移植到Android平台。

convert_variables_to_constants固化模型結構

  下面給出一個樣例:

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    # 導出當前計算圖的GraphDef部分,只需要這一步就可以完成從輸入層到輸出層的過程
    graph_def = tf.get_default_graph().as_graph_def()

    # 將圖中的變量及其取值轉化為常量,同時將圖中不必要的節點去掉
    # 在下面,最后一個參數['add']給出了需要保存的節點名稱
    # add節點是上面定義的兩個變量相加的操作
    # 注意這里給出的是計算節點的的名稱,所以沒有后面的 :0
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, (['add']))
    # 將導出的模型存入文件
    with tf.gfile.GFile('model/combined_model.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())

  通過下面的程序可以直接計算定義加法運算的結果,當只需要得到計算圖中某個節點的取值時,這提供了一個更加方便的方法,以后將使用這種方法來使用訓練好的模型完成遷移學習。

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = 'model/combined_model.pb'
    # 讀取保存的模型文件,並將文件解析成對應的GraphDef Protocol Buffer
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # 將graph_def 中保存的圖加載到當前的圖中,
    # return_elements = ['add: 0'] 給出了返回的張量的名稱
    # 在保存的時候給出的是計算節點的名稱,所以為add
    # 在加載的時候給出的張量的名稱,所以是 add:0
    result = tf.import_graph_def(graph_def, return_elements=['add: 0'])
    print(sess.run(result))
    # 輸出 [array([3.], dtype=float32)]

 

 2,持久化原理及數據格式

  上面學習了當調用 saver.save 函數時,TensorFlow程序會自動生成四個文件。TensorFlow模型的持久化就是通過這個四個文件完成的。這里我們詳細學習一下這個三個文件中保存的內容以及數據格式。

  TensorFlow是一個通過圖的形式來表述計算的編程系統,TensorFlow程序中所有計算都會被表達為計算圖上的節點。TensorFlow通過元圖(MetaGraph)來記錄計算圖中節點的信息以及運行計算圖中節點所需要的元數據。TensorFlow中元圖是由 MetaGraphDef Protocol Buffer 定義的。MetaGraphDef 中的內容就構成了TensorFlow 持久化的第一個文件,以下代碼給出了MetaGraphDef類型的定義:

message MetaGraphDef{
    MeatInfoDef meta_info_def = 1;
    GraphDef graph_def = 2;
    SaverDef saver_def = 3;
    map<string,CollectionDef> collection_def = 4;
    map<string,SignatureDef> signature_def = 5;
}

  從上面代碼中可以看到,元圖中主要記錄了五類信息,下面結合變量相加樣例的持久化結果,逐一介紹MetaGraphDef類型的每一個屬性中存儲的信息。保存 MetaGraphDef 信息的文件默認為以 .meta 為后綴名,在上面,文件 model.ckpt.meta 中存儲的就是元圖的數據。直接運行其樣例得到的是一個二進制文件,無法直接查看。為了方便調試,TensorFlow提供了 export_meta_graph 函數,這函數支持以json格式導出 MetaGraphDef Protocol Buffer。下面代碼展示了如何使用這個函數:

import tensorflow as tf

# 定義變量相加的計算
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()
# 通過  export_meta_graph() 函數導出TensorFlow計算圖的元圖,並保存為json格式
saver.export_meta_graph('model/model.ckpt.meda.json', as_text=True)

  通過上面給出的代碼,我們可以將計算圖元圖以json的格式導出並存儲在 model.ckpt.meda.json 文件中。下面給出這個文件的大概內容:

  我們從JSON文件中可以看到確實是五類信息。下面結合這JSON文件的具體內容來學習一下TensorFlow中元圖存儲的信息。

1,meta_info_def屬性

  meta_info_def 屬性是通過MetaInfoDef定義的。它記錄了TensorFlow計算圖中的元數據以及TensorFlow程序中所有使用到的運算方法的信息,下面是 MetaInfoDef Protocol Buffer 的定義:

message MetaInfoDef{
    #saver沒有特殊指定,默認屬性都為空。meta_info_def屬性里只有stripped_op_list屬性不能為空。
    #該屬性不能為空
    string meta_graph_version = 1;
    #該屬性記錄了計算圖中使用到的所有運算方法的信息,該函數只記錄運算信息,不記錄計算的次數
    OpList stripped_op_list = 2;
    google.protobuf.Any any_info = 3;
    repeated string tags = 4;
}

  TensorFlow計算圖的元數據包括了計算圖的版本號(meta_graph_version屬性)以及用戶指定的一些標簽(tags屬性)。如果沒有在 saver中特殊指定,那么這些屬性都默認為空。

  在model.ckpt.meta.json文件中,meta_info_def 屬性里只有 stripped_op_list屬性是不為空的。stripped_op_list 屬性記錄了TensorFlow計算圖上使用到的所有運算方法的信息。注意stripped_op_list 屬性保存的是 TensorFlow 運算方法的信息,所以如果某一個運算在TensorFlow計算圖中出現了多次,那么在 stripped_op_list  也只會出現一次。比如在 model.ckpt.meta.jspm 文件的 stripped_op_list  屬性只有一個 Variable運算,但是這個運算在程序中被使用了兩次。 

  stripped_op_list 屬性的類型是  OpList。OpList 類型是一個 OpDef類型的列表,以下代碼給出了 OpDef 類型的定義:

message opDef{
    string name = 1;#定義了運算的名稱
    repeated ArgDef input_arg = 2; #定義了輸入,屬性是列表
    repeated ArgDef output_arg =3; #定義了輸出,屬性是列表
    repeated AttrDef attr = 4;#給出了其他運算的參數信息
    string summary = 5;
    string description = 6;
    OpDeprecation deprecation = 8;
    bool is_commutative = 18;
    bool is_aggregate = 16
    bool is_stateful = 17;
    bool allows_uninitialized_input = 19;
};

  OpDef 類型中前四個屬性定義了一個運算最核心的信息。OpDef 中的第一個屬性 name 定義了運算的名稱,這也是一個運算唯一的標識符。在TensorFlow計算圖元圖的其他屬性中,比如下面要學習的GraphDef屬性,將通過運算名稱來引用不同的運算。OpDef 的第二個和第三個屬性為 input_arg 和 output_arg,他們定義了運算的輸出和輸入。因為輸入輸出都可以有多個,所以這兩個屬性都是列表。第四個屬性Attr給出了其他的運算參數信息。在JSON文件中共定義了七個運算,下面將給出比較有代表性的一個運算來輔助說明OpDef 的數據結構。

op {
    name: "Add"
    input_arg{
        name: "x"
        type_attr:"T"
    }
    input_arg{
        name: "y"
        type_attr:"T"
    }
    output_arg{
        name: "z"
        type_attr:"T"
    }
    attr{
        name:"T"
        type:"type"
        allow_values{
            list{
                type:DT_HALF
                type:DT_FLOAT
                ...
            }
        }
    }
}

  上面給出了名稱為Add的運算。這個運算有兩個輸入和一個輸出,輸入輸出屬性都指定了屬性 type_attr,並且這個屬性的值為 T。在OpDef的Attr屬性中,必須要出現名稱(name)為 T的屬性。以上樣例中,這個屬性指定了運算輸入輸出允許的參數類型(allowed_values)。

2,graph_def 屬性

  graph_def 屬性主要記錄了TensorFlow 計算圖上的節點信息。TensorFlow計算圖的每一個節點對應了TensorFlow程序中一個運算,因為在 meta_info_def 屬性中已經包含了所有運算的具體信息,所以 graph_def 屬性只關注運算的連接結構。graph_def屬性是通過 GraphDef Protocol Buffer 定義的,graph_def主要包含了一個 NodeDef類型的列表。一下代碼給出了 graph_def 和NodeDef類型中包含的信息:

message GraphDef{
    #GraphDef的主要信息存儲在node屬性中,他記錄了Tensorflow計算圖上所有的節點信息。
    repeated NodeDef node = 1;
    VersionDef versions = 4; #主要儲存了Tensorflow的版本號
};

message NodeDef{
    #NodeDef類型中有一個名稱屬性name,他是一個節點的唯一標識符,在程序中,通過節點的名稱來獲得相應的節點。
    string name = 1;

    '''
    op屬性給出了該節點使用的Tensorflow運算方法的名稱。
    通過這個名稱可以在TensorFlow計算圖元圖的meta_info_def屬性中找到該運算的具體信息。
    '''
    string op = 2;

    '''
    input屬性是一個字符串列表,他定義了運算的輸入。每個字符串的取值格式為弄的:src_output
    node部分給出節點名稱,src_output表明了這個輸入是指定節點的第幾個輸出。
    src_output=0時可以省略src_output部分
    '''
    repeated string input = 3;

    #制定了處理這個運算的設備,可以是本地或者遠程的CPU or GPU。屬性為空時自動選擇
    string device = 4;

    #制定了和當前運算有關的配置信息
    map<string, AttrValue> attr = 5;
};

  GraphDef中的versions屬性比較簡單,它主要存儲了TensorFlow的版本號。和其他屬性類似,NodeDef 類型中有一個名稱屬性 name,它是一個節點的唯一標識符,在TensorFlow程序中可以通過節點的名稱來獲取響應節點。 NodeDef 類型中 的 device屬性指定了處理這個運算的設備。運行TensorFlow運算的設備可以是本地機器的CPU或者GPU,當device屬性為空時,TensorFlow在運行時會自動選取一個最適合的設備來運行這個運算,最后NodeDef類型中的Attr屬性指定了和當前運算相關的配置信息。

  下面列舉了 model.ckpt.meta.json 文件中的一個計算節點來更加具體的了解graph_def屬性:

graph def {
    node {
        name: "v1"
        op: "Variable"
        attr {
            key:"_output_shapes"
            value {
                list{ shape { dim { size: 1 } } }
            }
        }
    }
    attr { 
        key :"dtype"
        value {
            type: DT_FLOAT
            }
        }           
        ...
    }
    node {
        name :"add"
        op :"Add"
        input :"v1/read" #read指讀取變量v1的值
        input: "v2/read"
        ...
    }
    node {
        name: "save/control_dependency" #指系統在完成tensorflow模型持久化過程中自動生成一個運算。
        op:"Identity"
        ...
    }
    versions {
        producer :24 #給出了文件使用時的Tensorflow版本號。
    }
}

  上面給出了 model.ckpt.meta.json文件中 graph_def 屬性里面比較有代表性的幾個節點。第一個節點給出的是變量定義的運算。在TensorFlow中變量定義也是一個運算,這個運算的名稱為 v1(name:),運算方法的名稱是Variable(op: "Variable")。定義變量的運算可以有很多個,於是在NodeDef類型的node屬性中可以有多個變量定義的節點。但是定義變量的運算方法只用到了一個,於是在MetaInfoDef類型的 stripped_op_list 屬性中只有一個名稱為Variable 的運算方法。除了制定計算圖中的節點的名稱和運算方法。NodeDef類型中還定義了運算相關的屬性。在節點 v1中,Attr屬性指定了這個變量的維度以及類型。

  給出的第二個節點是代表加法運算的節點。它指定了2個輸入,一個為 v1/read,另一個為 v2/read。其中 v1/read 代表的節點可以讀取變量 v1的值,因為 v1的值是節點 v1/read的第一個輸出,所以后面的:0就可以省略了。v2/read也類似的代表了變量v2的取值。以上樣例文件中給出的最后一個名稱為 save/control_dependency,該節點是系統在完成TensorFlow模型持久化過程中自動生成的一個運算。在樣例文件的最后,屬性versions給出了生成 model.ckpt.meta.json 文件時使用的TensorFlow版本號。

3,saver_def 屬性

  saver_def 屬性中記錄了持久化模型時需要用到的一些參數,比如保存到文件的文件名,保存操作和加載操作的名稱以及保存頻率,清理歷史記錄等。saver_def 屬性的類型為SaverDef,其定義如下:

message SaverDef {
    string filename_tensor_name = 1;
    string save_tensor_name = 2;
    string restore_op_name = 3;
    int32 max_to_keep = 4;
    bool sharded = 5;
    float keep_checkpoint_every_n_hours = 6;
    enum CheckpointFormatVersion {
        LEGACY = 0;
        V1 = 1;
        V2 = 2;
    }
    CheckpointFormatVersion version = 7;
}

  下面給出了JSON文件中 saver_def 屬性的內容:

saver_def {
  filename_tensor_name: "save/Const:0"
  save_tensor_name: "save/control_dependency:0"
  restore_op_name: "save/restore_all"
  max_to_keep: 5
  keep_checkpoint_every_n_hours: 10000.0
  version: V2
}

  filename_tensor_name 屬性給出了保存文件名的張量名稱,這個張量就是節點 save/Const的第一個輸出。save_tensor_name屬性給出了持久化TensorFlow模型的運算所對應的節點名稱。從上面的文件中可以看出,這個節點就是在 graph_def 屬性中給出的 save/control_dependency節點。和持久化TensorFlow模型運算對應的是加載TensorFlow模型的運算,這個運算的名稱是由 restore_op_name 屬性指定。max_to_keep 屬性和 keep_checkpoint_every_n_hours屬性設置了 tf.train.Saver 類清理之前保存的模型的策略。比如當 max_to_keep 為5的時候,在第六次調用 saver.save 時,第一次保存的模型就會被自動刪除,通過設置 keep_checkpoint_every_n_hours,每n小時可以在 max_to_keep 的基礎上多保存一個模型。

4,collection def 屬性

  在TensorFlow的計算圖(tf.Graph)中可以維護不同集合,而維護這些集合的底層實現就是通過collection_def 這個屬性。collection_def 屬性是一個從集合名稱到集合內容的映射,其中集合名稱為字符串,而集合內容為 CollectionDef Protocol Buffer。以下代碼給出了 CollectionDef類型的定義:

message CollectionDef {
    message Nodelist {
    #用於維護計算圖上的節點集合
        repeated string value = 1;
    }

    message BytesList {
    #維護字符串或者系列化之后的Procotol Buffer的集合。例如張量是通過Protocol Buffer表示的,而張量的集合是通過BytesList維護的。
        repeated bytes value = 1 ;
    }

    message Int64List {
        repeated int64 value = 1[packed = true];
    }
    message FloatList {
        repeated float value = 1[packed = true] ;
    }
    message AnyList {
        repeated google.protobuf.Any value= 1;
    }
    oneof kind {
        NodeList node_list = 1;
        BytesList bytes_lista = 2;
        Int64List int64_list = 3;
        Floatlist float_list = 4;
        AnyList any_list = 5;
    }
}

  通過上面的定義可以看出,TensorFlow計算圖上的集合主要可以維護四類不同的集合。NodeList用於維護計算圖上節點的集合。BytesList 可以維護字符串或者系列化之后 Procotol Buffer的集合。比如張量是通過Procotol Buffer表示的,而張量的集合是通過BytesList維護的,我們將在JSON文件中看到具體樣例。Int64List用於維護整數集合,FloatList用於維護實數集合。下面給出了JSON文件中collection_def 屬性的內容:

collection_def {
  key: "trainable_variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
    }
  }
}
collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
    }
  }
}

  從上面的文件可以看到樣例程序中維護了兩個集合。一個是所有變量的集合,這個集合的名稱是Variables。另外一個是可訓練變量的集合。名為 trainable_variables。在樣例程序中,這兩個集合中的元素是一樣的,都是變量 v1和 v2,他們是系統自動維護的。

  model.ckpt 文件中列表的第一行描述了文件的元信息,比如在這個文件中存儲的變量列表,列表剩下的每一行保存了一個變量的片段。變量片段的信息是通過SavedSlice Protocol Buffer 定義的。SavedSlice 類型中保存了變量的名稱,當前片段的信息以及變量取值。TensorFlow提供了  tf.train.NewCheckpointReader 類來查看 model.ckpt文件中保存的變量信息,下面代碼展示了如何使用tf.train.NewCheckpointReader 類:

#_*_coding:utf-8_*_
import tensorflow as tf

# tf.train.NewCheckpointReader()  可以讀取 checkpoint文件中保存的所有變量
reader = tf.train.NewCheckpointReader('model/model.ckpt')

# 獲取所有變量列表,這是一個從變量名到變量維度的字典
all_variables = reader.get_variable_to_shape_map()
for variable_name in all_variables:
    # variable_name 為變量名稱, all_variables[variable_name]為變量的維度
    print(variable_name, all_variables[variable_name])

#獲取名稱為v1 的變量的取值
print('Value for variable v1 is ', reader.get_tensor('v1'))
'''
v1 [1]     # 變量v1的維度為[1]
v2 [1]     # 變量v2的維度為[1]
Value for variable v1 is  [1.]   # 變量V1的取值為1
'''

  最后一個文件的名字是固定的,叫checkpoint。這個文件是 tf.train.Saver類自動生成且自動維護的。在 checkpoint 文件中維護了由一個 tf.train.Saver類持久化的所有 TensorFlow模型文件的文件名。當某個保存的TensorFlow模型文件被刪除的,這個模型所對應的文件名也會從checkpoint文件中刪除。checkpoint中內容格式為 CheckpointState Protocol Buffer,下面給出了 CheckpointState 類型的定義。

message CheckpointState {
    string model_checkpoint_path = 1,
    repeated string all_model_checkpoint_paths = 2;
}

  model_checkpoint_path 屬性保存了最新的TensorFlow模型文件的文件名。 all_model_checkpoint_paths 屬性列表了當前還沒有被刪除的所有TensorFlow模型文件的文件名。下面給出了生成的某個checkpoint文件:

model_checkpoint_path: "modeltest.ckpt"
all_model_checkpoint_paths: "modeltest.ckpt"

 

二,將CKPT轉化為pb格式

  很多時候,我們需要將TensorFlow的模型導出為單個文件(同時包含模型結構的定義與權重),方便在其他地方使用(如在Android中部署網絡)。利用 tf.train.write_graph() 默認情況下只能導出了網絡的定義(沒有權重),而利用 tf.train.Saver().save() 導出的文件 graph_def 與權重時分離的,因此需要采用別的方法。我們知道,graph_def 文件中沒有包含網絡中的 Variable值(通常情況存儲了權重),但是卻包含了constant 值,所以如果我們能把Variable 轉換為 constant,即可達到使用一個文件同時存儲網絡架構與權重的目標。

  (PS:利用tf.train.write_graph() 保存模型,該方法只是保存了模型的結構,並不保存訓練完畢的參數值。)

  TensorFlow 為我們提供了 convert_variables_to_constants() 方法,該方法可以固化模型結構,將計算圖中的變量取值以常量的形式保存,而且保存的模型可以移植到Android平台。

  將CKPT轉換成 PB格式的文件的過程如下:

  • 1,通過傳入 CKPT模型的路徑得到模型的圖和變量數據
  • 2,通過 import_meta_graph 導入模型中的圖
  • 3,通過saver.restore 從模型中恢復圖中各個變量的數據
  • 4,通過 graph_util.convert_variables_to_constants 將模型持久化

  下面的CKPT 轉換成 PB格式例子,是之前訓練的GoogleNet InceptionV3模型保存的ckpt轉pb文件的例子:

#_*_coding:utf-8_*_
import tensorflow as tf
from tensorflow.python.framework import graph_util
from create_tf_record import *

resize_height = 224  # 指定圖片高度
resize_width = 224   # 指定圖片寬度

def freeze_graph(input_checkpoint, output_graph):
    '''

    :param input_checkpoint:
    :param output_graph:  PB 模型保存路徑
    :return:
    '''
    # 檢查目錄下ckpt文件狀態是否可用
    # checkpoint = tf.train.get_checkpoint_state(model_folder)
    # 得ckpt文件路徑
    # input_checkpoint = checkpoint.model_checkpoint_path

    # 指定輸出的節點名稱,該節點名稱必須是元模型中存在的節點
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # 獲得默認的圖
    input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當前的圖

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢復圖並得到數據
        # 模型持久化,將變量值固定
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            # 等於:sess.graph_def
            input_graph_def=input_graph_def,
            # 如果有多個輸出節點,以逗號隔開
            output_node_names=output_node_names.split(","))

        # 保存模型
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())  # 序列化輸出
        # 得到當前圖有幾個操作節點
        print("%d ops in the final graph." % len(output_graph_def.node))

        # for op in graph.get_operations():
        #     print(op.name, op.values())

  說明

  • 1,函數 freeze_graph中,最重要的就是要確定“指定輸出的節點名稱”,這個節點名稱必須是原模型中存在的節點,對於 freeze 操作,我們需要定義輸出節點的名字。因為網絡其實是比較復雜的,定義了輸出節點的名字,那么freeze操作的時候就只把輸出該節點所需要的子圖都固化下來,其他無關的就舍棄掉。因為我們 freeze 模型的目的是接下來做預測,所以 output_node_names 一般是網絡模型最后一層輸出的節點名稱,或者說我們預測的目標。
  • 2,在保存的時候,通過 convert_variables_to_constants 函數來指定需要固化的節點名稱,對於下面的代碼,需要固化的節點只有一個:output_node_names。注意節點名稱與張量名稱的區別。比如:“input:0  是張量的名稱”,而“input” 表示的是節點的名稱。
  • 3,源碼中通過 graph=tf.get_default_graph() 獲得默認的圖,這個圖就是由 saver=tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) 恢復的圖,因此必須先執行 tf.train.import_meta_graph,再執行 tf.get_default_graph()。
  • 4,實質上,我們可以直接在恢復的會話 sess 中,獲得默認的網絡圖,更簡單的方法,如下:
def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路徑
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt文件狀態是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路徑

    # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        # 恢復圖並得到數據
        saver.restore(sess, input_checkpoint)
        # 模型持久化,將變量值固定
        output_graph_def = graph_util.convert_variables_to_constants(  
            sess=sess,
            input_graph_def=sess.graph_def,  # 等於:sess.graph_def
            # 如果有多個輸出節點,以逗號隔開
            output_node_names=output_node_names.split(","))

        # 保存模型
        with tf.gfile.GFile(output_graph, "wb") as f:
            # 序列化輸出
            f.write(output_graph_def.SerializeToString())
        # 得到當前圖有幾個操作節點
        print("%d ops in the final graph." % len(output_graph_def.node))  

  調用方法很簡單,輸入 ckpt 模型路徑,輸出 Pb模型的路徑即可:

# 輸入ckpt模型路徑
input_checkpoint='model/model.ckpt-10000'

# 輸出pb模型的路徑
out_pb_path="model/frozen_model.pb"

# 調用freeze_graph將ckpt轉為pb
freeze_graph(input_checkpoint,out_pb_path)  

  注意:在保存的時候,通過convert_variables_to_constants 函數來指定需要固化的節點名稱,對於上面的代碼,需要固化的節點只有一個 : output_nideo_names。因此,其他網絡模型,也可以通過簡單的修改輸出的節點名稱output_node_names將ckpt轉為pb文件。

  PS:注意節點名稱,應包含 name_scope 和 variable_scope命名空間,並用“/”隔開,如“InceptionV3/Logits/SpatialSqueeze”。

2.1 對指定輸出的節點名稱的理解

  如果說我們使用InceptionV3算法進行訓練,那么指定輸出的節點名稱如下:

# 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
output_node_names = "InceptionV3/Logits/SpatialSqueeze"

  那么為什么呢?

  我去查看了InceptionV3的源碼,首先模型的輸入名字叫做 InceptionV3;

  其次它要的是輸出的節點,我們看InceptionV3算法的輸出,也就是最后一層的源碼,部分源碼如下:

# Final pooling and prediction
with tf.variable_scope('Logits'):
  if global_pool:
    # Global average pooling.
    net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='GlobalPool')
    end_points['global_pool'] = net
  else:
    # Pooling with a fixed kernel size.
    kernel_size = _reduced_kernel_size_for_small_input(net, [8, 8])
    net = slim.avg_pool2d(net, kernel_size, padding='VALID',
                          scope='AvgPool_1a_{}x{}'.format(*kernel_size))
    end_points['AvgPool_1a'] = net
  if not num_classes:
    return net, end_points
  # 1 x 1 x 2048
  net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
  end_points['PreLogits'] = net
  # 2048
  logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                       normalizer_fn=None, scope='Conv2d_1c_1x1')
  if spatial_squeeze:
    logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
  # 1000
end_points['Logits'] = logits
end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

  我們會發現最后一層的名字為  Logits,輸出的的name = 'SpatialSqueeze'。

  所以我的理解是指定輸出的節點名稱是模型在代碼中的名稱+最后一層的名稱+輸出節點的名稱。當然這里只有一個輸出。

  如果不知道網絡節點名稱,或者說不想去模型中找節點名稱,那么我們可以在加載完模型的圖數據之后,可以輸出圖中的節點信息查看一下模型的輸入輸出節點:

for op in tf.get_default_graph().get_operations():
    print(op.name, op.values())

  這樣就可以找出輸出節點名稱。那我也在考慮如果只輸出最后節點的名稱是否可行呢?

  我測試了名字改為下面幾種:

    # output_node_names = 'SpatialSqueeze'
    # output_node_names = 'MobilenetV1/SpatialSqueeze'
    output_node_names = 'MobilenetV1/Logits/SpatialSqueeze'

  也就是不添加模型名稱和最后一層的名稱,添加模型名稱不添加最后一層的名稱。均報錯:

AssertionError: MobilenetV1/SpatialSqueeze is not in graph

  所以這里還是乖乖使用全稱。

那最后輸出的節點名稱到底是什么呢?怎么樣可以直接高效的找出呢?

  首先呢,我個人認為,最后輸出的那一層,應該必須把節點名稱命名出來,另外怎么才能確定我們的圖結構里有這個節點呢?百度了一下,有人說可以在TensorBoard中查找到,TensorBoard只能在Linux中使用,在Windows中得到的TensorBoard查看不了,是亂碼文件,在Linux中就沒有問題。所以如果你的Windows可以查看,就不需要去Linux中跑了。

  查看TensorBoard

tensorboard --logdir = “保存tensorboard的絕對路徑”

  敲上面的命令,然后就可以得到一個網址,把這個網址復制到瀏覽器上打開,就可以得到圖結構,然后點開看看,有沒有output這個節點,也可以順便查看一下自己的網絡圖。但是這個方法我沒有嘗試。我繼續百度了一下,哈哈哈哈,查到了下面的方法。

  就是如果可以按照下面四步驟走的話基本就不需要上面那么麻煩了:

  首先在ckpt模型的輸入輸出張量名稱,然后將ckpt文件生成pb文件;再查看生成的pb文件的輸入輸出節點,運行pb文件,進行網絡預測。所以這里關注的重點就是如何查看ckpt網絡的輸入輸出張量名稱和如何查看生成的pb文件的輸入輸出節點。

2.2  查看ckpt網絡的輸入輸出張量名稱

  首先我們找到網絡訓練后生成的ckpt文件,運行下面代碼查看自己模型的輸入輸出張量名稱(用於保存pb文件時保留這兩個節點):

def check_out_pb_name(checkpoint_path):
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        res = reader.get_tensor(key)
        print('tensor_name: ', key)
        print('a.shape: %s'%[res.shape])

if __name__ == '__main__':
    # 輸入ckpt模型路徑
    checkpoint_path = 'modelsmobilenet/model.ckpt-100000'
    check_out_pb_name(checkpoint_path)

  這里我繼續使用自己用的mobilenetV1模型,運行后的代碼部分結果如下:

tensor_name:  MobilenetV1/Conv2d_10_depthwise/BatchNorm/gamma
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_0/weights
a.shape: [(3, 3, 3, 32)]
tensor_name:  MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_11_depthwise/BatchNorm/gamma/Adadelta_1
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Adadelta_1
a.shape: [(3, 3, 256, 1)]
tensor_name:  MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Adadelta_1
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_5_depthwise/depthwise_weights/Adadelta
a.shape: [(3, 3, 256, 1)]
tensor_name:  MobilenetV1/Conv2d_0/BatchNorm/moving_variance
a.shape: [(32,)]
tensor_name:  MobilenetV1/Conv2d_10_pointwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_5_depthwise/BatchNorm/beta/Adadelta_1
a.shape: [(256,)]
tensor_name:  MobilenetV1/Conv2d_0/BatchNorm/beta
a.shape: [(32,)]
tensor_name:  MobilenetV1/Conv2d_8_pointwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_10_depthwise/BatchNorm/moving_variance
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_7_depthwise/BatchNorm/beta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_0/BatchNorm/beta/Adadelta
a.shape: [(32,)]
tensor_name:  MobilenetV1/Conv2d_0/BatchNorm/gamma
a.shape: [(32,)]

... ...

tensor_name:  MobilenetV1/Conv2d_9_depthwise/BatchNorm/beta/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_depthwise/BatchNorm/gamma/Adadelta_1
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_depthwise/depthwise_weights/Adadelta_1
a.shape: [(3, 3, 512, 1)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Adadelta
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/gamma/Adadelta_1
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/BatchNorm/moving_mean
a.shape: [(512,)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/weights
a.shape: [(1, 1, 512, 512)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/weights/Adadelta
a.shape: [(1, 1, 512, 512)]
tensor_name:  MobilenetV1/Conv2d_9_pointwise/weights/Adadelta_1
a.shape: [(1, 1, 512, 512)]
tensor_name:  MobilenetV1/Logits/Conv2d_1c_1x1/weights
a.shape: [(1, 1, 1024, 51)]
tensor_name:  MobilenetV1/Logits/Conv2d_1c_1x1/weights/Adadelta_1
a.shape: [(1, 1, 1024, 51)]
tensor_name:  MobilenetV1/Logits/Conv2d_1c_1x1/weights/Adadelta
a.shape: [(1, 1, 1024, 51)]

  我的模型是使用TensorFlow官網中標准的MoiblenetV1模型,所以輸入輸出張量比較容易找到,那如果自己的模型比較復雜(或者說是別人重構的模型),那如何找呢?

  那找到模型的定義,然后在模型的最前端打印出輸入張量,在最后打印出輸出張量。

  注意上面雖然最后輸出的張量名稱為:MobilenetV1/Logits/Conv2d_1c_1x1,但是如果我們直接用這個,還是會報錯的,這是為什么呢?這就得去看模型文件,上面也有,這里再粘貼一下(還是利用MobilenetV1模型):

with tf.variable_scope(scope, 'MobilenetV1', [inputs], reuse=reuse) as scope:
  with slim.arg_scope([slim.batch_norm, slim.dropout],
                      is_training=is_training):
    net, end_points = mobilenet_v1_base(inputs, scope=scope,
                                        min_depth=min_depth,
                                        depth_multiplier=depth_multiplier,
                                        conv_defs=conv_defs)
    with tf.variable_scope('Logits'):
      if global_pool:
        # Global average pooling.
        net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
        end_points['global_pool'] = net
      else:
        # Pooling with a fixed kernel size.
        kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
        net = slim.avg_pool2d(net, kernel_size, padding='VALID',
                              scope='AvgPool_1a')
        end_points['AvgPool_1a'] = net
      if not num_classes:
        return net, end_points
      # 1 x 1 x 1024
      net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
      logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                           normalizer_fn=None, scope='Conv2d_1c_1x1')
      if spatial_squeeze:
        logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
    end_points['Logits'] = logits
    if prediction_fn:
      end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

  最后這里,他對Logits變量進行了刪除維度為1的過程。並且將名稱重命名為SpatialSqueeze,一般如果不進行這一步就沒問題。所以我們如果出問題了,就對模型進行查看,當然第二個方法是可行的。

 

2.3  查看生成的pb文件的輸入輸出節點

  查看pb文件的節點,只是為了驗證一下,當然也可以不查看,直接去上面拿到結果即可,就是輸出節點的名稱。

def create_graph(out_pb_path):
    # 讀取並創建一個圖graph來存放訓練好的模型
    with tf.gfile.FastGFile(out_pb_path, 'rb') as f:
        # 使用tf.GraphDef() 定義一個空的Graph
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')

def check_pb_out_name(out_pb_path, result_file):
    create_graph(out_pb_path)
    tensor_name_list = [tensor.name for tensor in
                        tf.get_default_graph().as_graph_def().node]
    with open(result_file, 'w+') as f:
        for tensor_name in tensor_name_list:
            f.write(tensor_name+'\n')

  我們運行后,查看對應的TXT文件,可以看到,輸入輸出的節點和前面是對應的:

  或者這樣查看:

pb_path=["./test.pb",'./my_frozen_mobilenet_v1.pb'][1]
with tf.Session() as sess:
    with open(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()

        print('>>>打印輸入節點的結構如下:\n', graph_def.node[0])

        print('>>>打印輸出節點的結構如下:\n',graph_def.node[-1])

  這樣就解決了這個問題,最后使用pb模型進行預測即可。下面是這兩個查找輸出節點的完整代碼:

# _*_coding:utf-8_*_
from tensorflow.python import pywrap_tensorflow
import os
import tensorflow as tf

def check_out_pb_name(checkpoint_path):
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        res = reader.get_tensor(key)
        print('tensor_name: ', key)
        print('res.shape: %s'%[res.shape])

def create_graph(out_pb_path):
    # 讀取並創建一個圖graph來存放訓練好的模型
    with tf.gfile.FastGFile(out_pb_path, 'rb') as f:
        # 使用tf.GraphDef() 定義一個空的Graph
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Imports the graph from graph_def into the current default Graph.
        tf.import_graph_def(graph_def, name='')

def check_pb_out_name(out_pb_path, result_file):
    create_graph(out_pb_path)
    tensor_name_list = [tensor.name for tensor in
                        tf.get_default_graph().as_graph_def().node]
    with open(result_file, 'w+') as f:
        for tensor_name in tensor_name_list:
            f.write(tensor_name+'\n')



if __name__ == '__main__':
    # 輸入ckpt模型路徑
    checkpoint_path = 'modelsmobilenet/model.ckpt-100000'
    check_out_pb_name(checkpoint_path)

    # 輸出pb模型的路徑
    out_pb_path = 'modelmobilenet.pb'
    result_file = 'mobilenet_graph.txt'
    check_pb_out_name(out_pb_path, result_file)

 

2.4  查看h5文件的輸入輸出節點

  如果訓練的模型是h5類型的,我們可以直接在h5中查看輸入輸出的節點名稱,然后在 Pb中使用。

  代碼如下:

def contrastive_loss(y_true, y_pred):
    '''Contrastive loss from Hadsell-et-al.'06
        http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
        '''
    margin = 1
    square_pred = K.square(y_pred)
    margin_square = K.square(K.maximum(margin - y_pred, 0))
    return K.mean(y_true * square_pred + (1 - y_true) * margin_square)


def h5_to_pb(h5_file, output_dir, model_name, out_prefix="output_"):
    h5_model = load_model(h5_file, custom_objects={'contrastive_loss': contrastive_loss})
    print(h5_model.input)
    # [<tf.Tensor 'input_2:0' shape=(?, 80, 80) dtype=float32>, <tf.Tensor 'input_3:0' shape=(?, 80, 80) dtype=float32>]
    print(h5_model.output)  # [<tf.Tensor 'lambda_1/Sqrt:0' shape=(?, 1) dtype=float32>]
    print(len(h5_model.outputs))  # 1

 

 

三,使用pb模型預測

  下面是pb模型預測的代碼:

def freeze_graph_test(pb_path, image_path):
    '''
    :param pb_path: pb文件的路徑
    :param image_path: 測試圖片的路徑
    :return:
    '''
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # 定義輸入的張量名稱,對應網絡結構的輸入張量
            # input:0作為輸入圖像,keep_prob:0作為dropout的參數,測試時值為1,is_training:0訓練參數
            input_image_tensor = sess.graph.get_tensor_by_name("input:0")
            input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
            input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")

            # 定義輸出的張量名稱
            output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")

            # 讀取測試圖片
            im = read_image(image_path, resize_height, resize_width, normalization=True)
            im = im[np.newaxis, :]
            # 測試讀出來的模型是否正確,注意這里傳入的是輸出和輸入節點的tensor的名字,不是操作節點的名字
            # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
            out = sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                                          input_keep_prob_tensor: 1.0,
                                                          input_is_training_tensor: False})
            print("out:{}".format(out))
            score = tf.nn.softmax(out, name='pre')
            class_id = tf.argmax(score, 1)
            print("pre class_id:{}".format(sess.run(class_id)))

  

3.1  說明

1,與ckpt預測不同的是,pb文件已經固化了網絡模型結構,因此,即使不知道原訓練模型(train)的源碼,我們也可以恢復網絡圖,並進行預測。恢復模型非常簡單,只需要從讀取的序列化數據中導入網絡結構即可:

tf.import_graph_def(output_graph_def, name="")

2,但是必須知道原網絡模型的輸入和輸出的節點名稱(當然了,傳遞數據時,是通過輸入輸出的張量來完成的)。由於InceptionV3模型的輸入有三個節點,因此這里需要定義輸入的張量名稱,它對應的網絡結構的輸入張量:

input_image_tensor = sess.graph.get_tensor_by_name("input:0")

input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")

input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")

  以及輸出的張量名稱:

output_tensor_name = sess.graph.get_tensor_by_name(
                                        "InceptionV3/Logits/SpatialSqueeze:0")

  如何獲取輸入輸出張量的名稱,上面有寫,這里不再贅述。需要注意的是:使用pb獲取的張量名稱和使用 h5獲取的張量名稱有一個區別,如下:

pb 獲取的張量名稱是:
        ['input_2', 'input_3', 'output_1']

而 h5 獲取的張量名稱是:
        ['input_2:0',  'input_3:0',  'lambda_1/Sqrt:0']

需要注意的是,我們下面肯定需要使用張量名稱,而使用哪個呢?
如果使用Pb獲取的張量名稱,則會報下面的錯誤:
ValueError: The name 'conv2d_1_input' refers to an Operation,
 not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".


所以我們需要使用 h5 獲取的張量名稱,即格式為tensor格式:<op_name>:<output_index>

   雖然我們使用的是pb預測,但是格式必須是 tensor格式,不然會報錯。

3,預測時,需要 feed輸入數據

# 測試讀出來的模型是否正確
# 注意這里傳入的是輸出和輸入節點的tensor的名字,不是操作節點的名字
# out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", 
                 feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                            input_keep_prob_tensor:1.0,
                                            input_is_training_tensor:False})

4,其他網絡模型預測時,也可以通過修改輸入和輸出的張量的名稱。

(PS:注意張量的名稱,即為:節點名稱+ “:”+“id號”,如"InceptionV3/Logits/SpatialSqueeze:0")

   

  完整的CKPT轉換成PB格式和預測的代碼如下:

# _*_coding:utf-8_*_
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
import cv2

'''
checkpoint文件是檢查點文件,文件保存了一個目錄下所有模型文件列表。
model.ckpt.data文件保存了TensorFlow程序中每一個變量的取值
model.ckpt.index文件則保存了TensorFlow程序中變量的索引
model.ckpt.meta文件則保存了TensorFlow計算圖的結構
'''


def freeze_graph(input_checkpoint, output_graph):
    '''
    指定輸出的節點名稱
    將模型文件和權重文件整合合並為一個文件
    :param input_checkpoint:
    :param output_graph: PB模型保存路徑
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder)
    # 檢查目錄下的ckpt文件狀態是否可以用
    # input_checkpoint = checkpoint.model_checkpoint_path  # 得ckpt文件路徑

    # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
    # PS:注意節點名稱,應包含name_scope 和 variable_scope命名空間,並用“/”隔開,
    output_node_names = 'MobilenetV1/Logits/SpatialSqueeze'
    # 首先通過下面函數恢復圖
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    # 然后通過下面函數獲得默認的圖
    graph = tf.get_default_graph()
    # 返回一個序列化的圖代表當前的圖
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        # 加載已經保存的模型,恢復圖並得到數據
        saver.restore(sess, input_checkpoint)
        # 在保存的時候,通過下面函數來指定需要固化的節點名稱
        output_graph_def = graph_util.convert_variables_to_constants(
            # 模型持久化,將變量值固定
            sess=sess,
            input_graph_def=input_graph_def,  # 等於:sess.graph_def
            # freeze模型的目的是接下來做預測,
            # 所以 output_node_names一般是網絡模型最后一層輸出的節點名稱,或者說我們預測的目標
            output_node_names=output_node_names.split(',')  # 如果有多個輸出節點,以逗號隔開
        )

        with tf.gfile.GFile(output_graph, 'wb') as f:  # 保存模型
            # 序列化輸出
            f.write(output_graph_def.SerializeToString())
        # # 得到當前圖有幾個操作節點
        print('%d ops in the final graph' % (len(output_graph_def.node)))

        # 這個可以得到各個節點的名稱,如果斷點調試到輸出結果,看看模型的返回數據
        # 大概就可以猜出輸入輸出的節點名稱
        for op in graph.get_operations():
            print(op.name)
            # print(op.name, op.values())


def read_image(filename, resize_height, resize_width, normalization=False):
    '''
    讀取圖片數據,默認返回的是uint8,[0,255]
    :param filename:
    :param resize_height:
    :param resize_width:
    :param normalization:是否歸一化到[0.,1.0]
    :return: 返回的圖片數據
    '''

    bgr_image = cv2.imread(filename)
    if len(bgr_image.shape) == 2:  # 若是灰度圖則轉為三通道
        print("Warning:gray image", filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # 將BGR轉為RGB
    # show_image(filename,rgb_image)
    # rgb_image=Image.open(filename)
    if resize_height > 0 and resize_width > 0:
        rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # 不能寫成:rgb_image=rgb_image/255
        rgb_image = rgb_image / 255.0
    # show_image("src resize image",image)
    return rgb_image


def freeze_graph_test(pb_path, image_path):
    '''
    預測pb模型的代碼
    :param pb_path: pb文件的路徑
    :param image_path: 測試圖片的路徑
    :return:
    '''
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, 'rb') as f:
            output_graph_def.ParseFromString(f.read())
            # 恢復模型,從讀取的序列化數據中導入網絡結構即可
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # 定義輸入的張量名稱,對應網絡結構的輸入張量
            # input: 0 作為輸入圖像,
            # keep_prob:0作為dropout的參數,測試時值為1,
            # is_training: 0 訓練參數
            input_image_tensor = sess.graph.get_tensor_by_name('input:0')
            input_keep_prob_tensor = sess.graph.get_tensor_by_name('keep_prob:0')
            input_is_training_tensor = sess.graph.get_tensor_by_name('is_training:0')

            # 定義輸出的張量名稱:注意為節點名稱 + “:”+id好
            name = 'MobilenetV1/Logits/SpatialSqueeze:0'
            output_tensor_name = sess.graph.get_tensor_by_name(name=name)

            # 讀取測試圖片
            im = read_image(image_path, resize_height, resize_width, normalization=True)
            im = im[np.newaxis, :]
            # 測試讀出來的模型是否正確,注意這里傳入的時輸出和輸入節點的tensor的名字,不是操作節點的名字
            out = sess.run(output_tensor_name, feed_dict={
                input_image_tensor: im,
                input_keep_prob_tensor: 1.0,
                input_is_training_tensor: False
            })
            print("out:{}".format(out))
            score = tf.nn.softmax(out, name='pre')
            class_id = tf.argmax(score, 1)
            print('Pre class_id:{}'.format(sess.run(class_id)))


if __name__ == '__main__':
    # 輸入ckpt模型路徑
    input_checkpoint = 'modelsmobilenet/model.ckpt-100000'
    # 輸出pb模型的路徑
    out_pb__path = 'modelmobilenet.pb'
    # 指定圖片的高度,寬度
    resize_height, resize_width = 224, 224
    depth = 3

    # 調用freeze_graph將ckpt轉pb
    # freeze_graph(input_checkpoint, out_pb__path)

    # 測試pb模型
    image_path = '5.png'
    freeze_graph_test(pb_path=out_pb__path, image_path=image_path)

  結果如下:

out:[[ -6.41409     -7.542293    -4.79263     -0.8360114   -5.9790826
    4.5435553   -0.36825374  -6.4866605   -2.4342375   -0.77123785
   -3.8730755   -2.9347122   -1.2668624   -2.0682898   -4.8219028
   -4.0054555   -4.929347    -4.3350396   -1.3294952   -5.2482243
   -5.6148944   -0.5199025   -2.8043954   -7.536846    -8.050901
   -5.4447656   -6.8323407   -6.221056    -8.040736    -7.3237658
  -10.494858    -9.077686    -6.8210897  -10.038142    -9.5562935
   -3.884094    -4.31197     -7.0326185   -2.3761833   -9.571469
    1.0321844   -9.319367    -5.5040984   -4.881267    -6.99698
   -9.591501    -8.059127    -7.494555   -10.593867    -6.862433
   -4.373736  ]]
Pre class_id:[5]

  我將測試圖片命名為5,就是與結果相對應,結果一致。表明使用pb預測出來了,並且預測正確。

  這里解釋一下,我是使用MobileNetV1模型進行訓練一個51個分類的數據,而拿到的第6個類的數據進行測試(我的標簽是從0開始的),這里測試正確。

四,h5轉pb,轉 tflite

  我們通常使用Keras訓練模型后,保存模型格式類型為 hdf5格式,也就是 .h5文件,但是我們如果想要移植到移動端,特別是基於 TensorFlow 支持的移動端,那就需要轉換成 tflite格式。

  這里廢話不多說,直接上轉換的代碼:

from keras.models import load_model
from tensorflow.python.framework import graph_util
from tensorflow import lite
from keras import backend as K
import os


def h5_to_pb(h5_file, output_dir, model_name, out_prefix="output_"):
    h5_model = load_model(h5_file, custom_objects={'contrastive_loss': contrastive_loss})
    out_nodes = []
    for i in range(len(h5_model.outputs)):
        out_nodes.append(out_prefix + str(i + 1))
        tf.identity(h5_model.output[i], out_prefix + str(i + 1))
    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    with tf.gfile.GFile(os.path.join(output_dir, model_name), "wb") as filemodel:
        filemodel.write(main_graph.SerializeToString())
    print("pb model: ", {os.path.join(output_dir, model_name)})


def pb_to_tflite(pb_file, tflite_file):
    inputs = ["input_1"]  # 模型文件的輸入節點名稱
    classes = ["output_1"]  # 模型文件的輸出節點名稱
    converter = tf.lite.TocoConverter.from_frozen_graph(pb_file, inputs, classes)
    tflite_model = converter.convert()
    with open(tflite_file, "wb") as f:
        f.write(tflite_model)


def contrastive_loss(y_true, y_pred):
    '''Contrastive loss from Hadsell-et-al.'06
        http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
        '''
    margin = 1
    square_pred = K.square(y_pred)
    margin_square = K.square(K.maximum(margin - y_pred, 0))
    return K.mean(y_true * square_pred + (1 - y_true) * margin_square)


def h5_to_tflite(h5_file, tflite_file):
    converter = lite.TFLiteConverter.from_keras_model_file(h5_file,
                                                           custom_objects={'contrastive_loss': contrastive_loss})
    tflite_model = converter.convert()
    with open(tflite_file, 'wb') as f:
        f.write(tflite_model)


if __name__ == '__main__':
    h5_file = 'screw_10.h5'
    tflite_file = 'screw_10.tflite'
    pb_file = 'screw_10.pb'
    # h5_to_tflite(h5_file, tflite_file)
    # h5_to_pb(h5_file=h5_file, model_name=pb_file, output_dir='', )
    pb_to_tflite(pb_file, tflite_file)

   h5轉Pb后,使用pb模型預測,我這里寫一個孿生網絡的pb預測代碼(其他的照着改就行):

import tensorflow as tf
from tensorflow.python.platform import gfile
import cv2


def predict_pb(pb_model_path, image_path1, image_path2, target_size):
    sess = tf.Session()
    with gfile.FastGFile(pb_model_path, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
    # 輸入  這里有兩個輸入
    input_x = sess.graph.get_tensor_by_name('input_2:0')
    input_y = sess.graph.get_tensor_by_name('input_3:0')
    # 輸出
    op = sess.graph.get_tensor_by_name('lambda_1/Sqrt:0')

    image1 = cv2.imread(image_path1)
    image2 = cv2.imread(image_path2)
    # 灰度化,並調整尺寸
    image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
    image1 = cv2.resize(image1, target_size)
    image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY)
    image2 = cv2.resize(image2, target_size)
    data1 = np.array([image1], dtype='float') / 255.0
    data2 = np.array([image2], dtype='float') / 255.0
    y_pred = sess.run(op, {input_x: data1, input_y: data2})
    print(y_pred)

 

 

   此文是自己的學習筆記總結,學習於《TensorFlow深度學習框架》,俗話說,好記性不如爛筆頭,寫寫總是好的,所以若侵權,請聯系我,謝謝。

  其實網上有很多ckpt轉pb的文章,大多數來自下面的博客,我這里也只是做個筆記,記錄自己的學習過程,並且調試通代碼,方便自己使用。

還有參考文獻:https://blog.csdn.net/guyuealian/article/details/82218092

https://blog.csdn.net/weixin_42535742/article/details/93657397


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM