Keras + Flask 提供接口服務的坑~~~


最近在搞Keras,訓練完的模型要提供個預測服務出來。就想了個辦法,通過Flask提供一個http服務,后來發現也能正常跑,但是每次預測都需要加載模型,效率非常低。

然后就把模型加載到全局,每次要用的時候去拿來用就行了,可是每次去拿的時候,都會報錯.

如:

ValueError: Tensor Tensor(**************) is not an element of this graph.

這個問題就是在你做預測的時候,他加載的圖,不是你第一次初始化模型時候的圖,所以圖里面沒有模型里的那些參數和節點

在網上找了個靠譜的解決方案,親測有效,原文:https://wolfx.cn/flask-keras-server/

 

解決方式如下:

When you create a Model, the session hasn't been restored yet. All placeholders, variables and ops that are defined in Model.init are placed in a new graph, which makes itself a default graph inside with block. This is the key line:

with tf.Graph().as_default():
  ...

This means that this instance of tf.Graph() equals to tf.get_default_graph() instance inside with block, but not before or after it. From this moment on, there exist two different graphs.

When you later create a session and restore a graph into it, you can't access the previous instance of tf.Graph() in that session. Here's a short example:

with tf.Graph().as_default() as graph:
  var = tf.get_variable("var", shape=[3], initializer=tf.zeros_initializer)

This works

with tf.Session(graph=graph) as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(var))  # ok because `sess.graph == graph`

This fails

saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
with tf.Session() as sess:
  saver.restore(sess, "/tmp/model.ckpt")
  print(sess.run(var))   # var is from `graph`, not `sess.graph`!

The best way to deal with this is give names to all nodes, e.g. 'input', 'target', etc, save the model and then look up the nodes in the restored graph by name, something like this:

saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
with tf.Session() as sess:
  saver.restore(sess, "/tmp/model.ckpt")      
  input_data = sess.graph.get_tensor_by_name('input')
  target = sess.graph.get_tensor_by_name('target')

This method guarantees that all nodes will be from the graph in session.

Try to start with:

import tensorflow as tf
global graph,model
graph = tf.get_default_graph()

When you need to use predict:

with graph.as_default(): y = model.predict(X)

------------------------------------------------------------華麗的分割線------------------------------------------------------------

 

下面就上我自己的代碼,解釋一下如何使用:

我原來的代碼是這樣的:

  

 1 def get_angiogram_time(video_path):
 2     start = time.time()
 3     global _MODEL_MA,_MODEL_TIME,_GRAPH_MA, _GRAPH_TIME
 4     if _MODEL_MA == None:
 5         model_ma = ma_ocr.Training_Predict()
 6         model_time = time_ocr.Training_Predict()
 7 
 8         model_ma.build_model()
 9         model_time.build_model()
10 
11         model_ma.load_model("./model/ma_gur_ctc_model.h5base")
12         model_time.load_model("./model/time_gur_ctc_model.h5base")
13 
14         _MODEL_MA = model_ma
15         _MODEL_TIME = model_time
16 
17     indexes = _MODEL_MA.predict(video_path)
18     time_dict = _MODEL_TIME.predict(video_path,indexes)
19     end = time.time()
20     print("耗時:%.2f s" % (end-start))
21     return json.dumps(time_dict)
 1     def predict(self, video_path):
 2         start = time.time()
 3 
 4         vid = cv2.VideoCapture(video_path)
 5         if not vid.isOpened():
 6             raise IOError("Couldn't open webcam or video")
 7         # video_fps = vid.get(cv2.CAP_PROP_FPS)
 8 
 9         X = self.load_video_data(vid)
10         y_pred = self.base_model.predict(X)
11         shape = y_pred[:, :, :].shape  # 2:
12         out = K.get_value(K.ctc_decode(y_pred[:, :, :], input_length=np.ones(shape[0]) * shape[1])[0][0])[:,
13               :seq_len]  # 2:
14         print()

 

當實行到第10行 :y_pred = self.base_model.predict(X)

就會拋錯:Cannot use the given session to evaluate tensor: the tensor's graph is different from the session's graph.

 

大致意思就是:當前session里的圖和模型中的圖的各種參數不匹配

 

修改后代碼:

 

 1 def get_angiogram_time(video_path):
 2     start = time.time()
 3     global _MODEL_MA,_MODEL_TIME,_GRAPH_MA, _GRAPH_TIME
 4     if _MODEL_MA == None:
 5         model_ma = ma_ocr.Training_Predict()
 6         model_time = time_ocr.Training_Predict()
 7 
 8         model_ma.build_model()
 9         model_time.build_model()
10 
11         model_ma.load_model("./model/ma_gur_ctc_model.h5base")
12         model_time.load_model("./model/time_gur_ctc_model.h5base")
13 
14         _MODEL_MA = model_ma
15         _MODEL_TIME = model_time
16         _GRAPH_MA = tf.get_default_graph()
17         _GRAPH_TIME = tf.get_default_graph()
18 
19     with _GRAPH_MA.as_default():
20         indexes = _MODEL_MA.predict(video_path)
21     with _GRAPH_TIME.as_default():
22         time_dict = _MODEL_TIME.predict(video_path,indexes)
23     end = time.time()
24     print("耗時:%.2f s" % (end-start))
25     return json.dumps(time_dict)

主要修改在第16,17,19,21行

定義了一個全局的圖,每次都用這個圖

 

 

完美解決~

 

PS:問了一下專門做AI的朋友,他們公司是用TensorFlow Server提供對外服務的,我最近也要研究一下Tensorflow Server,本人是個AI小白,剛剛入門,寫的不對還請指正,謝謝!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM