【AI小疑问】h5转换pb模型代码(解决:找不到输入到ReadVariableOp的变量)Cannot find the variable that is an input to the ReadVariableOp.


有人建议降级 Keras 版本,感觉或许有用;但为了统一版本,我找到了另一个解决方案(暂且不知道还有什么坑):

import keras.backend as K
# Set the Keras learning phase to 0 (test mode) before loading the model.
# Note the alias is upper-case ``K`` -- a lower-case ``k`` is undefined.
K.set_learning_phase(0)

解决方案是在加载模型之前,将 Keras 的学习阶段(learning phase)设置为测试模式(0)。

调用方式

如何使用
Keras 模型可以用 model.save() 保存为单个 .hdf5 或 .h5 文件,该文件同时存储网络结构和权重。然后可以按如下方式调用此工具,将该模型转换为 TensorFlow 模型:

# Convert a full model saved with model.save(). The trailing backslashes are
# required so the shell treats the three lines as one command.
python keras_to_tensorflow.py \
    --input_model="path/to/keras/model.h5" \
    --output_model="path/to/save/model.pb"
Keras 模型也可以保存为两个单独的文件:用 model.save_weights() 保存权重的 .hdf5 或 .h5 文件,以及用 model.to_json() 保存网络结构的 .json 文件。这种情况下,可以按以下方式转换模型:

# Convert a weights-only .h5 plus its json architecture. The trailing
# backslashes are required so the shell treats the lines as one command.
python keras_to_tensorflow.py \
    --input_model="path/to/keras/model.h5" \
    --input_model_json="path/to/keras/model.json" \
    --output_model="path/to/save/model.pb"
尝试

# Print every supported flag together with its help text.
python keras_to_tensorflow.py --help
了解其他受支持的参数,例如 --quantize(量化)、--output_nodes_prefix 和 --save_graph_def。

  

 

这是修改后的最终代码(下载自 GitHub 上的某个项目):

  1 # -*- encoding: utf-8 -*-
  2 """
  3 @File    : ker_to_pb.py
  4 @Time    : 2020/5/12 20:23
  5 @Author  : xuluda
  6 @Email   : xuluda@163.com
  7 @Software: PyCharm
  8 
  9 """
 10 
 11 import tensorflow as tf
 12 from tensorflow.python.framework import graph_util
 13 from tensorflow.python.framework import graph_io
 14 from pathlib import Path
 15 from absl import app
 16 from absl import flags
 17 from absl import logging
 18 import keras
 19 from keras import backend as K
 20 from keras.models import model_from_json, model_from_yaml
 21 
 22 K.set_learning_phase(0)
 23 FLAGS = flags.FLAGS
 24 
 25 flags.DEFINE_string('input_model', None, 'Path to the input model.')
 26 flags.DEFINE_string('input_model_json', None, 'Path to the input model '
 27                                               'architecture in json format.')
 28 flags.DEFINE_string('input_model_yaml', None, 'Path to the input model '
 29                                               'architecture in yaml format.')
 30 flags.DEFINE_string('output_model', None, 'Path where the converted model will '
 31                                           'be stored.')
 32 flags.DEFINE_boolean('save_graph_def', False,
 33                      'Whether to save the graphdef.pbtxt file which contains '
 34                      'the graph definition in ASCII format.')
 35 flags.DEFINE_string('output_nodes_prefix', None,
 36                     'If set, the output nodes will be renamed to '
 37                     '`output_nodes_prefix`+i, where `i` will numerate the '
 38                     'number of of output nodes of the network.')
 39 flags.DEFINE_boolean('quantize', False,
 40                      'If set, the resultant TensorFlow graph weights will be '
 41                      'converted from float into eight-bit equivalents. See '
 42                      'documentation here: '
 43                      'https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms')
 44 flags.DEFINE_boolean('channels_first', False,
 45                      'Whether channels are the first dimension of a tensor. '
 46                      'The default is TensorFlow behaviour where channels are '
 47                      'the last dimension.')
 48 flags.DEFINE_boolean('output_meta_ckpt', False,
 49                      'If set to True, exports the model as .meta, .index, and '
 50                      '.data files, with a checkpoint file. These can be later '
 51                      'loaded in TensorFlow to continue training.')
 52 
 53 flags.mark_flag_as_required('input_model')
 54 flags.mark_flag_as_required('output_model')
 55 
 56 
 57 def load_model(input_model_path, input_json_path=None, input_yaml_path=None):
 58     if not Path(input_model_path).exists():
 59         raise FileNotFoundError(
 60             'Model file `{}` does not exist.'.format(input_model_path))
 61     try:
 62         model = keras.models.load_model(input_model_path)
 63         return model
 64     except FileNotFoundError as err:
 65         logging.error('Input mode file (%s) does not exist.', FLAGS.input_model)
 66         raise err
 67     except ValueError as wrong_file_err:
 68         if input_json_path:
 69             if not Path(input_json_path).exists():
 70                 raise FileNotFoundError(
 71                     'Model description json file `{}` does not exist.'.format(
 72                         input_json_path))
 73             try:
 74                 model = model_from_json(open(str(input_json_path)).read())
 75                 model.load_weights(input_model_path)
 76                 return model
 77             except Exception as err:
 78                 logging.error("Couldn't load model from json.")
 79                 raise err
 80         elif input_yaml_path:
 81             if not Path(input_yaml_path).exists():
 82                 raise FileNotFoundError(
 83                     'Model description yaml file `{}` does not exist.'.format(
 84                         input_yaml_path))
 85             try:
 86                 model = model_from_yaml(open(str(input_yaml_path)).read())
 87                 model.load_weights(input_model_path)
 88                 return model
 89             except Exception as err:
 90                 logging.error("Couldn't load model from yaml.")
 91                 raise err
 92         else:
 93             logging.error(
 94                 'Input file specified only holds the weights, and not '
 95                 'the model definition. Save the model using '
 96                 'model.save(filename.h5) which will contain the network '
 97                 'architecture as well as its weights. '
 98                 'If the model is saved using the '
 99                 'model.save_weights(filename) function, either '
100                 'input_model_json or input_model_yaml flags should be set to '
101                 'to import the network architecture prior to loading the '
102                 'weights. \n'
103                 'Check the keras documentation for more details '
104                 '(https://keras.io/getting-started/faq/)')
105             raise wrong_file_err
106 
107 
def main(args):
    """Convert the Keras model named by the flags into a frozen TensorFlow graph.

    Called by ``absl.app.run`` after flag parsing. Every option is read from
    ``FLAGS``; ``args`` (the remaining argv) is not used.

    Outputs, all written to the directory of ``--output_model``:
      * the frozen graph as a binary GraphDef at ``--output_model``;
      * ``<stem>.pbtxt``, the graph before freezing in text form, if
        ``--save_graph_def``;
      * a checkpoint named ``<stem>`` (.meta/.index/.data) if
        ``--output_meta_ckpt``.
    """
    # If output_model path is relative and in cwd, make it absolute from root
    output_model = Path(FLAGS.output_model)
    if str(output_model.parent) == '.':
        output_model = Path.cwd() / output_model

    output_fld = output_model.parent
    output_model_name = output_model.name
    output_model_stem = output_model.stem
    output_model_pbtxt_name = output_model_stem + '.pbtxt'

    # Create output directory if it does not exist
    output_fld.mkdir(parents=True, exist_ok=True)

    K.set_image_data_format(
        'channels_first' if FLAGS.channels_first else 'channels_last')

    model = load_model(FLAGS.input_model, FLAGS.input_model_json,
                       FLAGS.input_model_yaml)

    # TODO(amirabdi): Support networks with multiple inputs
    orig_output_node_names = [node.op.name for node in model.outputs]
    if FLAGS.output_nodes_prefix:
        converted_output_node_names = [
            '{}{}'.format(FLAGS.output_nodes_prefix, i)
            for i in range(len(orig_output_node_names))]
        # tf.identity only needs to add the renamed node to the default graph;
        # the returned tensors are not used afterwards.
        for output, node_name in zip(model.outputs,
                                     converted_output_node_names):
            tf.identity(output, name=node_name)
    else:
        converted_output_node_names = orig_output_node_names
    logging.info('Converted output node names are: %s',
                 str(converted_output_node_names))

    sess = K.get_session()
    if FLAGS.output_meta_ckpt:
        saver = tf.train.Saver()
        saver.save(sess, str(output_fld / output_model_stem))

    if FLAGS.save_graph_def:
        tf.train.write_graph(sess.graph.as_graph_def(), str(output_fld),
                             output_model_pbtxt_name, as_text=True)
        logging.info('Saved the graph definition in ascii format at %s',
                     str(output_fld / output_model_pbtxt_name))

    # Freeze before quantizing. quantize_weights only rewrites float Const
    # nodes, and the weights become Const nodes only once
    # convert_variables_to_constants has run.
    constant_graph = graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        converted_output_node_names)

    if FLAGS.quantize:
        from tensorflow.tools.graph_transforms import TransformGraph
        transforms = ["quantize_weights", "quantize_nodes"]
        constant_graph = TransformGraph(constant_graph, [],
                                        converted_output_node_names,
                                        transforms)

    graph_io.write_graph(constant_graph, str(output_fld), output_model_name,
                         as_text=False)
    logging.info('Saved the freezed graph at %s',
                 str(output_fld / output_model_name))


if __name__ == '__main__':
    # absl parses the command-line flags defined above (stopping early if a
    # required one is missing) and then calls main() with the leftover argv.
    app.run(main)

 

 

 

下面再附上我第一次写的代码:它虽然能用,但最后在 OpenVINO 上转换时,提示我的 pb 冻结得不完整。

# -*- encoding: utf-8 -*-
"""
@File    : h5_to_pb.py
@Time    : 2020/5/11 17:54
@Author  : xuluda
@Email   : xuluda@163.com
@Software: PyCharm

将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend


# Switch Keras to test mode (learning phase 0) before the model is loaded below,
# the workaround used in this article for the "Cannot find the variable that is
# an input to the ReadVariableOp" error when freezing.
backend.set_learning_phase(0)


# from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a loaded Keras model and write it as a binary TensorFlow .pb file.

    Output ``i`` of the model (counting from 1) is exposed through an identity
    node named ``out_prefix + str(i)``. Those names are the output nodes of the
    frozen graph.

    Args:
        h5_model: the loaded Keras model object (e.g. the result of
            ``keras.models.load_model``), not a path to the .h5 file.
        output_dir: str, directory the .pb file is written to; created,
            including missing parents, if it does not exist.
        model_name: str, file name of the .pb file.
        out_prefix: str, prefix of the renamed output nodes; change it to match
            the names your downstream tool expects.
        log_tensorboard: bool, if True also import the written .pb into a
            TensorBoard log under ``output_dir``.

    Returns:
        str: path of the written .pb file.
    """
    # makedirs (unlike mkdir) creates missing parents and accepts an existing dir.
    os.makedirs(output_dir, exist_ok=True)

    out_nodes = []
    # Iterate model.outputs, which is always a list. For a single-output model
    # model.output is a bare tensor, so model.output[i] would slice its batch
    # axis instead of selecting output i.
    for i, output in enumerate(h5_model.outputs, start=1):
        node_name = out_prefix + str(i)
        tf.identity(output, node_name)
        out_nodes.append(node_name)
    sess = backend.get_session()

    from tensorflow.python.framework import graph_util, graph_io
    # Write the .pb file: bake variable values into constants, then serialize.
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
    pb_path = os.path.join(output_dir, model_name)

    # Optionally write a TensorBoard log of the frozen graph.
    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(pb_path, output_dir)
    return pb_path


if __name__ == '__main__':
    import sys

    # Path of the .h5 model: the first command-line argument if given,
    # otherwise the original hard-coded default.
    input_path = '/Users/jack/Documents/DaiCode/T4/'
    weight_file = 'false_positive3.h5'
    if len(sys.argv) > 1:
        weight_file_path = sys.argv[1]
    else:
        weight_file_path = os.path.join(input_path, weight_file)
    # Same base name with a .pb extension. splitext handles .h5 and .hdf5 alike,
    # whereas slicing off the last 3 characters only worked for ".h5".
    output_graph_name = osp.splitext(osp.basename(weight_file_path))[0] + '.pb'

    # Output directory for the .pb file (and TensorBoard log).
    output_dir = osp.join(os.getcwd(), "trans_model")
    # Load the full model (architecture + weights) saved with model.save().
    h5_model = load_model(weight_file_path)
    h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
    print('Finished')

  


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM