Keras模型的保存方式


Keras模型的保存方式

在運行並且訓練出一個模型后獲得了模型的結構與許多參數,為了防止再次訓練以及需要更好地去使用,我們需要保存當前狀態

基本保存方式 h5

# 此處假設model為一個已經訓練好的模型類
model.save('my_model.h5')

轉換為json格式存儲基本參數

# 此處假設model為一個已經訓練好的模型類
json_string = model.to_json()
open('my_model_architecture.json','w').write(json_string)

轉換為二進制pb格式

以下代碼為我從網絡中尋找到的,可以將模型中的內容轉換為pb格式,但需要更改其中的h5為你的模型的h5

import sys \\
from keras.models import load_model \\
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) 
        output_names = output_names or [] 
        output_names += [v.op.name for v in tf.global_variables()] 
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,output_names,freeze_var_names)
    return frozen_graph

input_fld = sys.path[0] 
weight_file = 'my_model.h5'
output_graph_name = 'tensor_model.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
    os.mkdir(output_fld) 
    weight_file_path = osp.join(input_fld, weight_file)
K.set_learning_phase(0) 
net_model = load_model(weight_file_path) 
print('input is :', net_model.input.name) 
print ('output is:', net_model.output.name) 
sess = K.get_session() 
frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name]) 
from tensorflow.python.framework import graph_io 
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False) 
print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM