tensorflow 导入gfile模型文件


with tf.gfile.GFile(os.path.join(self.model_dir, 'ner_model.pb'), 'rb') as f:
      graph_def = self.tf.GraphDef()
      graph_def.ParseFromString(f.read())
      input_map = {"input_ids:0": self.input_ids,
                             'input_mask:0': self.input_mask}
       # 这就是我们要获取的op
      self.pred_ids = self.tf.import_graph_def(graph_def,
                                                         name='',
                                                         input_map=input_map,
                                                         return_elements=['pred_ids:0'])[0]
      graph = self.pred_ids.graph

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM