tensorflow學習筆記1:導出和加載模型


用一個非常簡單的例子學習導出和加載模型;

導出

寫一個y=a*x+b的運算,然后保存graph;

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

with tf.Session() as sess:
    a = tf.Variable(5.0, name='a')
    x = tf.Variable(6.0, name='x')
    b = tf.Variable(3.0, name='b')
    y = tf.add(tf.multiply(a,x),b, name="y")

    tf.global_variables_initializer().run()
    
    print (a.eval()) # 5.0
    print (x.eval()) # 6.0
    print (b.eval()) # 3.0
    print (y.eval()) # 33.0

    graph = convert_variables_to_constants(sess, sess.graph_def, ["y"])
    #writer = tf.summary.FileWriter("logs/", graph)
    tf.train.write_graph(graph, 'models/', 'test_graph.pb', as_text=False)

 

運行

在models目錄下生成了test_graph.pb;

注:convert_variables_to_constants操作是將模型參數froze(保存)進graph中,這時的graph相當於是sess.graph_def + checkpoint,即有模型結構也有模型參數;

 

加載

 只加載,獲取各個變量的值

import tensorflow as tf
from tensorflow.python.platform import gfile

with gfile.FastGFile("models/test_graph.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    output = tf.import_graph_def(graph_def, return_elements=['a:0', 'x:0', 'b:0','y:0'])
    #print(output)
    
with tf.Session() as sess:
    result = sess.run(output)
    print (result)

  

運行看以看到原本保存的結果(因為幾個變量都已經帶入模型,又從模型中加載了出來)

 

加載的時候修改變量值

5*2+3=13,結果正確

 

運行時修改變量值

加載時用一個占位符替掉x常量,在session運行時再給占位符填值;

5*3+3=18,也正確

 

修改計算結果

偷偷把結果給改了會怎么樣?

呵呵,不知原因為何;以后鑽進代碼了再說;

 

參考:

https://www.sohu.com/a/233679628_468681

http://blog.163.com/wujiaxing009@126/blog/static/7198839920174125748893/

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM