tf導出pb文件,以及如何使用pb文件


先羅列出來代碼,有時間再解釋

from tensorflow.python.framework import graph_util
import tensorflow as tf



def export_model(input_checkpoint, output_graph):
    #這個可以加載saver的模型
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph() # 獲得默認的圖
    input_graph_def = graph.as_graph_def()  # 返回一個序列化的圖代表當前的圖

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,將變量值固定
        sess=sess,
        input_graph_def=input_graph_def,# 等於:sess.graph_def
        output_node_names=['softmax_linear/softmax_linear','Cast_1'])# 如果有多個輸出節點,以逗號隔開這個是重點,輸入和輸出的參數都需要在這里記錄

        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化輸出
        

export_model('E:\\python\\image\\code2\\model_10\\model.ckpt',"E:\\python\\image\\code2\\model_10\\model.pb")

使用的代碼

import os
import numpy as np
import tensorflow as tf
import model_new
from PIL import Image
import matplotlib.pyplot as plt
import csv
import shutil
from tensorflow.python.platform import gfile

def get_one_image(img_dir):
        
        image = Image.open(img_dir)
        
        image = image.resize((128,128))
        image = np.array(image)

        return image, img_dir


def test_model(model_path, img_path):
    image_array,img_dir = get_one_image( img_path)
    image = tf.cast(image_array,tf.float32)
    #image = tf.image.per_image_standardization(image)
    image = tf.reshape(image,[1,128,128,3])

    with tf.Session() as sess:
       
        with gfile.FastGFile(model_path,'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def,name='')
        sess.run(tf.global_variables_initializer())
        
        input_x = sess.graph.get_tensor_by_name('Cast_1:0')
        out = sess.graph.get_tensor_by_name('softmax_linear/softmax_linear:0')
        ret = sess.run(out,  feed_dict={input_x: image.eval()})
        print(ret)
        



out_pb_path="E:\\python\\image\\code2\\model_10\\frozen_model.pb"
img_path = "E:\\python\\image\\code\\images\\0\\mmexport1540880139708.jpg"
test_model(out_pb_path,img_path)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM