1 如何使用pb文件保存和恢復模型進行遷移學習(學習Tensorflow 實戰google深度學習框架)


學習過程是Tensorflow 實戰google深度學習框架一書的第六章的遷移學習環節。

具體見我提出的問題:https://www.tensorflowers.cn/t/5314

參考https://blog.csdn.net/zhuiqiuk/article/details/53376283后,對代碼進行了修改。

問題的跟蹤情況記錄:

1 首先是保存模型:

import tensorflow as tf
from tensorflow.python.framework import graph_util
v1=tf.constant([10000.0],name='v1')
#v1 = tf.placeholder(tf.float32,shape=[1],name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op,{v1:[100]})
    print (sess.run(result,{v1:[1000]}))
    writer = tf.summary.FileWriter('./graphs/model_graph', sess.graph)
    writer.close()    
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def,['add'])
    with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())

因為inception v3接受輸入的tensor是Decode/Content:0,是一個const類型,就是tf.constant類型,而一開始,我並不明白問題的所在,就將tf.placeholder改為了tf.constant,而實際上,兩個都可以。問題的本身不是出在這里,而是對書本有錯誤的理解。

書上因為獲取的是兩個return elements,會自動從列表中取出元素。

而我獲得的是一個retrun elelment,則只能返回一個列表。

2 使用並加載持久化模型,直接調用模型的訓練參數進行計算。

#這是我以前寫的程序,是錯誤的
"""
因為我以前寫的只是獲取一個值。
而現在修正的v1則是一個tensor。我們可以修正tensor的值。
所以,tensorflow 實戰google深度學習框架中有重大bug。
不懂的聯系我手機 18627711314 傑
"""
import tensorflow as tf
import numpy as np
from numpy.random import RandomState
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename = "Saved_model/combined_model.pb"  
    #model_filename = "inception_dec_2015/tensorflow_inception_graph.pb"  
    
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    """將模型的相關信息寫入文件,和利用tensorboard進行可視化
    f = open("xiaojie2.txt", "w")
    print ("xiaojie2\n",file = f)
    print (graph_def,file=f)
    f.close()
    writer = tf.summary.FileWriter('./graphs/model_graph2', graph_def)
    writer.close()
    """
    """④輸出所有可訓練的變量名稱,也就是神經網絡的參數"""
    trainable_variables=tf.trainable_variables()
    variable_list_name = [c.name for c in tf.trainable_variables()]
    variable_list = sess.run(variable_list_name)
    for k,v in zip(variable_list_name,variable_list):
        print("variable name:",k)
        print("shape:",v.shape)
        #print(v) 
    """④輸出所有可訓練的變量名稱,也就是神經網絡的參數"""

v1= tf.import_graph_def(graph_def, return_elements=["v1:0"]) print (v1) print (sess.run(v1)) v2= tf.import_graph_def(graph_def, return_elements=["v2:0"]) print (sess.run(v2)) result = tf.import_graph_def(graph_def, return_elements=["add:0"]) print (sess.run(result)) x=np.array([2000.0]) print (sess.run(result,feed_dict={v1: x}))

這些都是參照書上使用inception模型時的做法,我參照着自己寫了一個模型,但是有重大bug

運行結果總是提示:

因為總是無法用feed_dict傳入我想計算的輸入。我一開始因為是tf.placeholder的原因,就參照inception的改為tf.constant。但是還是不行。后來在網上看到別人加載pb文件的一段代碼https://blog.csdn.net/zhuiqiuk/article/details/53376283,重新對代碼進行了修正。如下:

import tensorflow as tf
import numpy as np
from numpy.random import RandomState
from tensorflow.python.platform import gfile
with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    output_graph_path='Saved_model/combined_model.pb'
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
"""④輸出所有可訓練的變量名稱,也就是神經網絡的參數"""
        trainable_variables=tf.trainable_variables()
        variable_list_name = [c.name for c in tf.trainable_variables()]
        variable_list = sess.run(variable_list_name)
        for k,v in zip(variable_list_name,variable_list):
            print("variable name:",k)
            print("shape:",v.shape)
            #print(v) 
"""④輸出所有可訓練的變量名稱,也就是神經網絡的參數"""
        input_x = sess.graph.get_tensor_by_name("v1:0")
        result = tf.import_graph_def(graph_def, return_elements=["add:0"])
        print (input_x)
        x=np.array([2000.0])
        print (sess.run(result,feed_dict={input_x: x}))

問題的根本在於:

V1= tf.import_graph_def(graph_def, return_elements=["v1:0"])獲取的是

[<tf.Tensor 'import/v1:0' shape=(1,) dtype=float32>],是一個列表

而:input_x = sess.graph.get_tensor_by_name("v1:0")

獲取的是一個Tensor,即Tensor("v1:0", shape=(1,), dtype=float32)。

使用sess.run的時候,feed_dict要修正的是tensor,而不是一個list。因此,總會提出unhashable type:listd的報錯。

3 將原始錯誤程序的代碼改為:

v1= tf.import_graph_def(graph_def, return_elements=["v1:0"])
print (v1)
print (sess.run(v1))
v2= tf.import_graph_def(graph_def, return_elements=["v2:0"])
print (sess.run(v2))
result = tf.import_graph_def(graph_def, return_elements=["add:0"])
print (sess.run(result))
x=np.array([2000.0])

#print (sess.run(result,feed_dict={v1: x}))
print (sess.run(result,feed_dict={v1[0]: x}))

也可以正確運行

4 后來正確的程序還可以改為:

import tensorflow as tf
import numpy as np
from numpy.random import RandomState
from tensorflow.python.platform import gfile
output_graph_def = tf.GraphDef()
output_graph_path='Saved_model/combined_model.pb'
with open(output_graph_path, "rb") as f:
    output_graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
"""④輸出所有可訓練的變量名稱,也就是神經網絡的參數"""
    trainable_variables=tf.trainable_variables()
    variable_list_name = [c.name for c in tf.trainable_variables()]
    variable_list = sess.run(variable_list_name)
    for k,v in zip(variable_list_name,variable_list):
        print("variable name:",k)
        print("shape:",v.shape)
        #print(v) 
  """④輸出所有可訓練的變量名稱,也就是神經網絡的參數"""
    input_x = sess.graph.get_tensor_by_name("v1:0")
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print (input_x)
    x=np.array([2000.0])
    print (sess.run(result,feed_dict={input_x: x}))

需要注意的是:

首先,無論如何,加載pb以后,輸出所有可訓練的變量,都不可能輸出持久化模型中的變量。這一點以前就說過。以前說過,只能使用train.saver的方式。

其次,如果使用后一種方式,即sess.graph.get_tensor_by_name,則必須要有紅黃標注的那一幕。即:_ = tf.import_graph_def(output_graph_def, name="")

程序附件

鏈接:https://pan.baidu.com/s/11YtyDEyV84jONPi9tO2TCw 密碼:8mfj


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM