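1. Background
As a deep learning beginner, I suddenly needed an image classification model for a project, so I turned to the TensorFlow model library and used its framework for training and the follow-up steps. Project URL: https://github.com/tensorflow/models/tree/master/research/slim
Before using my real dataset, I first tried the flowers dataset that the project provides, with the inception_resnet_v2 model, chosen because its top-5 accuracy is comparatively high.
Then, following the directory structure of flowers, I organized my own data in a similar layout;
modeled on download_and_convert_flowers.py, I added my own data conversion file, convert_normal_data.py;
modeled on the dataset reader flowers.py, I added my own file, normal.py;
and then I followed the project's tutorial step by step to fine-tune, stopping once accuracy passed 90%.
But when it came to exporting the model, I ran into trouble.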
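For readers who want to see what "a similar layout" means in practice, here is a minimal sketch. It assumes the flowers convention of one sub-directory per class, with class ids assigned in sorted order and a labels.txt made of id:name lines (the format the prediction script in Section 4 parses). The helper names build_label_map and write_labels_file are my own for illustration and are not part of the slim code.

```python
import os
import tempfile


def build_label_map(dataset_dir):
    """Map each class sub-directory name to an integer id, in sorted order."""
    class_names = sorted(
        entry for entry in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, entry))
    )
    return {name: index for index, name in enumerate(class_names)}


def write_labels_file(label_map, path):
    """Write one 'id:name' line per class, ordered by id."""
    with open(path, "w") as f:
        for name, index in sorted(label_map.items(), key=lambda kv: kv[1]):
            f.write("%d:%s\n" % (index, name))


if __name__ == "__main__":
    # Throwaway directory tree standing in for a real dataset (hypothetical classes).
    with tempfile.TemporaryDirectory() as root:
        for class_name in ("suited", "unsuited"):
            os.makedirs(os.path.join(root, class_name))
        label_map = build_label_map(root)
        write_labels_file(label_map, os.path.join(root, "labels.txt"))
        print(label_map)
```

The number of entries in this map is the class count that the export step later has to agree with.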
2. Exporting the Inference Graph
The tutorial makes this look simple. First you export the model architecture:
Saves out a GraphDef containing the architecture of the model.
Then you write the trained checkpoint into that graph:
If you then want to use the resulting model with your own or pretrained checkpoints as part of a mobile model, you can run freeze_graph to get a graph def with the variables inlined
The command given in the tutorial is:
$ python export_inference_graph.py \
  --alsologtostderr \
  --model_name=inception_v3 \
  --output_file=/tmp/inception_v3_inf_graph.pb
Following this format, I switched the model to inception_resnet_v2 and then tried to load my checkpoint into the exported graph, but it kept failing with:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [1001] rhs shape= [2]
[[{{node save/Assign_916}}]]
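I asked in a chat group and was told that the number of outputs of the model's last layer had not been changed. So I rethought the approach and read the source of export_inference_graph.py. It passes a num_classes argument to the network, which determines the size of the final output layer, and that value comes from the dataset selected on the command line. The [1001] in the error matches the ImageNet class count used by slim, while [2] is my two-class checkpoint. So I added the dataset flags to the export command, and the final command was: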
python export_inference_graph.py \
  --alsologtostderr \
  --model_name=${MODEL_NAME} \
  --dataset_name=normal \
  --dataset_dir=${DATASET_DIR} \
  --output_file=/you/path/to/sava/${MODEL_NAME}_inf_graph.pb
That finally gave me my graph.pb.
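The shape mismatch behind the "Assign requires shapes of both tensors to match" error can be illustrated with a small, TensorFlow-free sketch. The variable names and the 1536 feature width below are illustrative assumptions, not values read from my files, and find_shape_mismatches is a hypothetical helper. With a real checkpoint, the name and shape pairs could probably be obtained from tf.train.list_variables instead of being typed by hand.

```python
def find_shape_mismatches(graph_shapes, checkpoint_shapes):
    """Return (name, graph_shape, checkpoint_shape) for variables whose shapes differ."""
    mismatches = []
    for name in sorted(graph_shapes):
        if name not in checkpoint_shapes:
            continue  # only compare variables present on both sides
        graph_shape = list(graph_shapes[name])
        checkpoint_shape = list(checkpoint_shapes[name])
        if graph_shape != checkpoint_shape:
            mismatches.append((name, graph_shape, checkpoint_shape))
    return mismatches


if __name__ == "__main__":
    # Assumed shapes: a graph exported for 1001 classes vs. a checkpoint
    # fine-tuned on 2 classes. Names are illustrative only.
    exported_graph = {
        "InceptionResnetV2/Logits/Logits/biases": [1001],
        "InceptionResnetV2/Logits/Logits/weights": [1536, 1001],
    }
    checkpoint = {
        "InceptionResnetV2/Logits/Logits/biases": [2],
        "InceptionResnetV2/Logits/Logits/weights": [1536, 2],
    }
    for name, lhs, rhs in find_shape_mismatches(exported_graph, checkpoint):
        print("%s: lhs shape=%s rhs shape=%s" % (name, lhs, rhs))
```

In the error message, lhs is the variable in the exported graph and rhs is the value stored in the checkpoint, so a check like this points directly at the output layer.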
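3. Freezing the Graph
Freezing was a big pitfall. Why? Because the official tutorial first builds freeze_graph with bazel and then uses it to freeze the model. That is where the trouble started: on Ubuntu 18.04 I could not install bazel with apt, so after some struggling I installed it with the install script that bazel provides.
Next I had to git clone the TensorFlow source and build it. The build threw a lot of errors, and after it failed, the GPU build of TensorFlow in my conda environment stopped working as well...
In the end I found that if you already have TensorFlow from conda or git, you can simply run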
find / -name freeze_graph.py
to locate this Python file, and then run:
python tensorflow/python/tools/freeze_graph.py \
  --input_graph=/you/path/to/sava/${MODEL_NAME}_inf_graph.pb \
  --input_checkpoint=/you/trained/checkpoints/model.ckpt-10000 \
  --input_binary=true \
  --output_node_names=InceptionResnetV2/Logits/Predictions \
  --output_graph=/your/path/to/save/frozen_graph.pb
With that, the model was finally exported.
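Searching the whole filesystem with find can be slow and may hit permission errors. As an alternative, here is a small sketch that asks Python where a module lives. It assumes the script is importable as tensorflow.python.tools.freeze_graph, which matches the path used in the command above but may differ between installations. locate_module_file is a hypothetical helper name.

```python
import importlib.util


def locate_module_file(module_name):
    """Return the source file of an importable module, or None if it cannot be found."""
    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        # Raised when a parent package in the dotted name is missing or invalid.
        return None
    if spec is None:
        return None
    return spec.origin


if __name__ == "__main__":
    # Module path assumed from tensorflow/python/tools/freeze_graph.py.
    print(locate_module_file("tensorflow.python.tools.freeze_graph"))
```

If this prints a path, that file can be passed to python exactly as in the command above. If it prints None, TensorFlow is not importable in the current environment.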
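4. Using the Model for Prediction
I mainly followed the blog post "[Deep Learning: model eval + model export] Using TensorFlow Slim to evaluate a trained model and export it", with some small adjustments: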
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Note: this script uses TensorFlow 1.x APIs (tf.Session, tf.gfile, tf.app).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy as np
import tensorflow as tf

FLAGS = None


class NodeLookup(object):
  """Maps integer class ids to human-readable labels read from an id:name file."""

  def __init__(self, label_lookup_path=None):
    self.node_lookup = self.load(label_lookup_path)

  def load(self, label_lookup_path):
    node_id_to_name = {}
    with open(label_lookup_path) as f:
      for line in f:
        line_list = line.strip().split(":")
        node_id_to_name[int(line_list[0])] = line_list[1]
    return node_id_to_name

  def id_to_string(self, node_id):
    if node_id not in self.node_lookup:
      return ''
    return self.node_lookup[node_id]


def create_graph():
  """Imports the frozen GraphDef at FLAGS.model_path into the default graph."""
  with tf.gfile.FastGFile(FLAGS.model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')


def preprocess_for_eval(image, height, width,
                        central_fraction=0.875, scope=None):
  with tf.name_scope(scope, 'eval_image', [image, height, width]):
    if image.dtype != tf.float32:
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Crop the central region of the image with an area containing 87.5% of
    # the original image.
    if central_fraction:
      image = tf.image.central_crop(image, central_fraction=central_fraction)

    if height and width:
      # Resize the image to the specified height and width.
      image = tf.expand_dims(image, 0)
      image = tf.image.resize_bilinear(image, [height, width],
                                       align_corners=False)
      image = tf.squeeze(image, [0])
    # Rescale from [0, 1] to [-1, 1].
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    return image


def run_inference_on_image(image):
  """Runs inference on an image.

  Args:
    image: Image file name.

  Returns:
    Nothing
  """
  # Decode and preprocess the image in its own temporary graph.
  with tf.Graph().as_default():
    image_data = tf.gfile.FastGFile(image, 'rb').read()
    image_data = tf.image.decode_jpeg(image_data)
    image_data = preprocess_for_eval(image_data, 299, 299)
    image_data = tf.expand_dims(image_data, 0)
    with tf.Session() as sess:
      image_data = sess.run(image_data)

  # Creates graph from saved GraphDef.
  create_graph()

  with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name(
        'InceptionResnetV2/Logits/Predictions:0')
    predictions = sess.run(softmax_tensor, {'input:0': image_data})
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup(FLAGS.label_path)

    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))


def main(_):
  image = FLAGS.image_file
  run_inference_on_image(image)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--model_path',
      type=str,
      help='Path to the frozen graph (.pb) file.'
  )
  parser.add_argument(
      '--label_path',
      type=str,
      help='Path to the labels file, one id:name pair per line.'
  )
  parser.add_argument(
      '--image_file',
      type=str,
      default='',
      help='Absolute path to image file.'
  )
  parser.add_argument(
      '--num_top_predictions',
      type=int,
      default=5,
      help='Display this many predictions.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
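The last two operations in preprocess_for_eval are easy to overlook, so here is the arithmetic on a single pixel as a plain-Python sketch. It assumes the decoded JPEG is 8-bit, in which case convert_image_dtype scales values into [0, 1] before the subtract and multiply steps. scale_pixel is a hypothetical name used only for this illustration.

```python
def scale_pixel(value):
    """Map an 8-bit pixel value (0..255) to the [-1.0, 1.0] range."""
    as_float = value / 255.0        # uint8 -> float in [0, 1]
    return (as_float - 0.5) * 2.0   # subtract 0.5, then multiply by 2.0


if __name__ == "__main__":
    for pixel in (0, 64, 128, 255):
        print(pixel, scale_pixel(pixel))
```

The frozen graph expects inputs in this range, so feeding raw 0..255 values to input:0 would most likely give poor predictions.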
Finally, test it with a single image:
python classify_image_inception_resnet_v2.py \
  --model_path /your/saved/path/frozen_graph.pb \
  --label_path /your/path/labels.txt \
  --image_file /your/path/test.jpg
The output:
unsuited (score = 0.94713)
suited (score = 0.05287)
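Only two lines are printed even though num_top_predictions defaults to 5, because the model has just two classes. The ranking step can be sketched without TensorFlow or NumPy as follows. The label ids (0 for suited, 1 for unsuited) are my assumption here, the scores are copied from the output above, and parse_labels and top_k are hypothetical helpers in the same spirit as NodeLookup and the argsort call in the script.

```python
def parse_labels(lines):
    """Parse 'id:name' lines into a dict of integer id -> label name."""
    lookup = {}
    for line in lines:
        node_id, name = line.strip().split(":", 1)
        lookup[int(node_id)] = name
    return lookup


def top_k(scores, k):
    """Return the indices of the k highest scores, best first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


if __name__ == "__main__":
    labels = parse_labels(["0:suited", "1:unsuited"])  # assumed id order
    scores = [0.05287, 0.94713]                        # from the output above
    for node_id in top_k(scores, 5):
        print("%s (score = %.5f)" % (labels[node_id], scores[node_id]))
```

Asking for more predictions than there are classes simply returns all of them, sorted by score.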
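I was a little pleased, but looking back, the whole thing was exhausting, and the conda TensorFlow GPU install is now broken and needs to be repaired.
5. References
(1) "[Deep Learning: model eval + model export] Using TensorFlow Slim to evaluate a trained model and export it" (original title: 【深度學習-模型eval+模型導出】使用Tensorflow Slim對訓練的模型進行評估+導出模型)
(2) "[TensorFlow series] Training your own dataset with Inception_resnet_v2 and monitoring it with TensorBoard" (original title: 【Tensorflow系列】使用Inception_resnet_v2訓練自己的數據集並用Tensorboard監控)
(End)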