tensorflow三種加載模型的方法和三種模型保存文件(.ckpt,.pb, SavedModel)


一、.ckpt文件的保存和加載

1、保存的文件

 

 這是我保存的文件,保存一次有四個文件:

checkpoint文件:用於告知某些TF函數,這是最新的檢查點文件(可以用記事本打開看一下)

.data文件:(后面綴的那一串我也布吉島是啥)這個文件保存的是圖中所有變量的值,沒有結構。

.index文件:可能是保存了一些必要的索引叭(這個文件不大清楚)。

.meta文件:保存了計算圖的結構,但是不包含里面變量的值。

使用這種方法保存模型時會保存成上面這四個文件,重新加載模型時通常只會用到.meta文件恢復圖結構然后用.data文件把各個變量的值再加進去。

2、保存模型的方法

代碼:

saver=tf.train.Saver(max_to_keep)

saver.save(sess,'D:/model',global_step=epoch)

創建一個saver(max_to_keep可設置要保存的模型的個數),調用save方法將當前sess會話中的圖和變量等信息保存到指定路徑,global_step代表當前的輪數,設置之后會在文件名后面綴一個‘-600’這樣的東西

3、重加載模型的方法

saver=tf.train.import_meta_graph('model1/my-model-190.meta')  #恢復計算圖結構

saver.restore(sess, tf.train.latest_checkpoint("model/"))  #恢復所有變量信息

現在sess中已經恢復了網絡結構和變量信息了,接下來可以直接用節點的名稱來調用:

print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})

或者采用:

graph = tf.get_default_graph()

input_x = graph.get_tensor_by_name('x:0')

input_y=graph.get_tensor_by_name('y:0')

op=graph.get_tensor_name('op:0')

print(sess.run(op,feed_dict={input_x:2,input_y:3)

這樣子使用也可。

4、PS

.ckpt方式保存模型,這種模型文件是依賴 TensorFlow 的,只能在其框架下使用。

https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/這篇文章詳細解釋和演示了利用ckpt文件保存模型,並進行遷移學習的方法(不過是英文版的)

二、.pb文件的保存和加載

1、保存的文件

 .pb文件里面保存了圖結構+數據,加載模型時只需要這一個文件就好。

2、保存模型的方法

constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op'])

with tf.gfile.FastGFile('D:/pycharm files/model.pb', mode='wb') as f:

  f.write(constant_graph.SerializeToString())

3、加載模型的方法

with tf.gfile.FastGFile(pb_file_path, 'rb') as f:

  graph_def = tf.GraphDef() # 生成圖

  graph_def.ParseFromString(f.read()) # 圖加載模型

   tf.import_graph_def(graph_def, name='')

接下來與前面的相同可以直接用節點的名稱來調用:

print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})

或者采用:

graph = tf.get_default_graph()

input_x = graph.get_tensor_by_name('x:0')

input_y=graph.get_tensor_by_name('y:0')

op=graph.get_tensor_name('op:0')

print(sess.run(op,feed_dict={input_x:2,input_y:3)

這樣子使用也可。

4、

谷歌推薦的保存模型的方式是保存模型為 PB 文件,它具有語言獨立性,可獨立運行,封閉的序列化格式,任何語言都可以解析它,它允許其他語言和深度學習框架讀取、繼續訓練和遷移 TensorFlow 的模型。另外的好處是保存為 PB 文件時候,模型的變量都會變成固定的,導致模型的大小會大大減小。

加載一個pb文件之后再對其進行微調(也就是將這個pb文件的網絡作為自己網絡的一部分),然后再保存成pb文件,后一個pb網絡會包含前一個pb網絡。

三、saved model

1、保存文件

在傳入的目錄下會有一個pb文件和一個文件夾:

2、保存模型

builder = tf.saved_model.builder.SavedModelBuilder(path)

builder.add_meta_graph_and_variables(sess,['cpu_server_1'])

3、加載模型

with tf.Session(graph=tf.Graph()) as sess:

  tf.saved_model.loader.load(sess, ['cpu_server_1'], pb_file_path+'savemodel')

接下來可以直接使用名字或者get_tensor_by_name后再進行使用

  input_x = sess.graph.get_tensor_by_name('x:0')

  input_y = sess.graph.get_tensor_by_name('y:0')

  op = sess.graph.get_tensor_by_name('op:0')

  ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})

關於savedmodel文件,可閱讀這篇博客,講的很清楚https://blog.csdn.net/thriving_fcl/article/details/75213361

 

下面代碼是實現上面三種保存模式的小例子,可以粘貼復制把相關的代碼注釋掉,運行一下看看結果,能加深理解:

import tensorflow as tf
with tf.Session() as sess:
  #搭建網絡
  x=tf.placeholder(tf.float32,name='x')
  y=tf.placeholder(tf.float32,name='y')
  b=tf.Variable(1.,name='b')
  xy=tf.multiply(x,y)
  op=tf.add(xy,b,name='op')
  sess.run(tf.global_variables_initializer())
  print(sess.run(op,feed_dict={x:2,y:3}))

  #ckpt保存
  saver=tf.train.Saver()
  saver.save(sess,'D:/pycharm files/111/ckpt/model_ck')

  #pb保存
  constant_graph=tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,['op'])
  with tf.gfile.FastGFile('D:/pycharm files/111/pb/model.pb','wb') as f:
  f.write(constant_graph.SerializeToString())

  #savedmodel文件保存
  builder=tf.saved_model.builder.SavedModelBuilder('D:/pycharm files/111/savemodel')
  builder.add_meta_graph_and_variables(sess,['cpu_server_1'])
  builder.save()

  print('over')


  #ckpt加載
  saver=tf.train.import_meta_graph('D:/pycharm files/111/ckpt/model_ck.meta')
  saver.restore(sess,tf.train.latest_checkpoint('D:/pycharm files/111/ckpt'))

  #pb加載
  with tf.gfile.FastGFile('D:/pycharm files/111/pb/model.pb','rb') as f:
    graph_def=tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def,name='')

  #savemodel加載
  tf.saved_model.loader.load(sess, ['cpu_server_1'], 'D:/pycharm files/111/savemodel')

  #測試模型加載是否成功
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  op = sess.graph.get_tensor_by_name('op:0')
  ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
  print(ret)

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM