TensorFlow: Classifying Images with the slim Framework's Classification Models


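TensorFlow's slim framework lets you define network architectures with code about as concise as Keras (although Keras itself is now also integrated into tf.contrib). In addition, models/slim provides an image classification interface, similar to the object detection interface covered earlier, which makes it easy to fine-tune a model on your own dataset.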

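The official documentation describes the whole workflow in detail, from data preparation and the model zoo of pre-trained models to fine-tuning and freezing the model, but it does not cover inference. Since every TensorFlow model is loaded the same way, however, a frozen slim model is called exactly like any other .pb model.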

According to the TensorFlow developers, saving and reusing a model in TensorFlow generally follows these steps: Build your graph --> write your graph --> import from written graph --> run compute etc.

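The four steps can be illustrated end to end on a toy graph. This is a minimal sketch, assuming the TF1-style graph API exposed as tf.compat.v1 (recent TF 1.x releases and TF 2.x); the two-input graph, the node names input/output and the file name toy.pb are hypothetical stand-ins for a real network. The weights here are constants, so no checkpoint or freezing is involved yet.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available in TensorFlow 2.x

# 1. Build your graph: a toy stand-in for a real network (weights are constants)
build_graph = tf.Graph()
with build_graph.as_default():
    x = tf1.placeholder(tf.float32, [None, 2], name="input")
    w = tf.constant([[2.0], [3.0]])
    tf.identity(tf.matmul(x, w), name="output")

# 2. Write your graph to disk as a binary GraphDef
out_dir = tempfile.mkdtemp()
tf.io.write_graph(build_graph.as_graph_def(), out_dir, "toy.pb", as_text=False)
pb_path = os.path.join(out_dir, "toy.pb")

# 3. Import from the written graph into a fresh Graph
graph_def = tf1.GraphDef()
with tf.io.gfile.GFile(pb_path, "rb") as f:
    graph_def.ParseFromString(f.read())
run_graph = tf.Graph()
with run_graph.as_default():
    tf1.import_graph_def(graph_def, name="")

# 4. Run compute, addressing tensors by their names
with tf1.Session(graph=run_graph) as sess:
    result = sess.run("output:0", feed_dict={"input:0": np.array([[1.0, 1.0]], np.float32)})
print(result)  # [[5.]]
```

Note that the session in step 4 only knows the tensor names; that is why the input placeholder of the real model needs a known name when the graph is exported.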
Below we use inception-resnet-v2, one of the networks shipped with slim, as the example:

1. export inference graph

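import tensorflow as tf
# nets/ comes from tensorflow/models (research/slim); that directory must be on PYTHONPATH
import nets.inception_resnet_v2 as net

slim = tf.contrib.slim

# checkpoint path
checkpoint_path = "/your/path/to/inception_resnet_v2.ckpt"  # ckpt file from the model zoo or from your own training/fine-tuning

# create a session
sess = tf.Session()
arg_scope = net.inception_resnet_v2_arg_scope()
# input placeholder; naming it "input" lets it be looked up as "input:0" after freezing
input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3], name='input')
with slim.arg_scope(arg_scope):
    # is_training=False puts batch norm and dropout in inference mode;
    # pass num_classes=... if your fine-tuned model does not use the default 1001 classes
    logits, end_points = net.inception_resnet_v2(inputs=input_tensor, is_training=False)

# restore the trained weights (this also checks that the graph matches the checkpoint)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)
# save the graph structure (without weights) as a binary GraphDef
with tf.gfile.GFile('/your/path/to/model_graph.pb', 'wb') as f:
    f.write(sess.graph_def.SerializeToString())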

2. freeze model

Here we use the freeze_graph tool in tensorflow/python/tools:

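$ bazel build tensorflow/python/tools:freeze_graph
$ bazel-bin/tensorflow/python/tools/freeze_graph \
    --input_graph=/your/path/to/model_graph.pb \
    --input_checkpoint=/your/path/to/inception_resnet_v2.ckpt \
    --input_binary=true \
    --output_graph=/your/path/to/frozen_graph.pb \
    --output_node_names=InceptionResnetV2/Logits/Predictions

Here --input_graph is the model_graph.pb written in step 1, --input_binary=true tells the tool that this file is a binary (not text) GraphDef, and --output_node_names is the softmax output node defined in the inception_resnet_v2 network. Comments cannot be placed after the trailing backslashes, because they would break the line continuation.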

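If building freeze_graph with bazel is inconvenient, the same freezing can be done from Python. This is a minimal sketch, assuming tf.compat.v1.graph_util.convert_variables_to_constants (deprecated in TF 2.x but still present). The one-variable toy graph and the names input, output and w are hypothetical, and global_variables_initializer stands in for restoring the real checkpoint.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available in TensorFlow 2.x

def freeze_toy_graph():
    """Build a toy graph with one variable and fold its value into a constant."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, [None, 2], name="input")
        w = tf1.get_variable("w", initializer=tf.constant([[2.0], [3.0]]))
        tf.identity(tf.matmul(x, w), name="output")
        with tf1.Session(graph=graph) as sess:
            # a real model would call saver.restore(sess, checkpoint_path) here instead
            sess.run(tf1.global_variables_initializer())
            # keep only the nodes "output" depends on and replace variables with constants
            return tf1.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), ["output"])

frozen_graph_def = freeze_toy_graph()
# serialize it the same way freeze_graph writes frozen_graph.pb
frozen_path = os.path.join(tempfile.mkdtemp(), "toy_frozen.pb")
with tf.io.gfile.GFile(frozen_path, "wb") as f:
    f.write(frozen_graph_def.SerializeToString())
```

For the real model, build the graph as in step 1, restore the checkpoint with saver.restore, and pass ["InceptionResnetV2/Logits/Predictions"] as the output node names.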
(Optional) visualize frozen graph

LOG_DIR = '/tmp/graphdeflogdir'
model_filename = '/your/path/to/frozen_graph.pb'

with tf.Session() as sess:
    with tf.gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # import into the session's (default) graph; import_graph_def itself returns None
    tf.import_graph_def(graph_def, name='')
    # write the imported graph to LOG_DIR so that TensorBoard can display it
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    writer.close()

Then run tensorboard --logdir=/tmp/graphdeflogdir and select the GRAPHS tab to see the structure of the frozen network.

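TensorBoard is not the only way to find node names such as the one passed to --output_node_names; the nodes of a GraphDef can also be searched directly. A minimal sketch: find_nodes is a hypothetical helper, and the toy graph only imitates the Logits/Predictions naming of the real network (assuming tf.compat.v1 is available).

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available in TensorFlow 2.x

def find_nodes(graph_def, keyword=""):
    """Return (name, op type) pairs for every node whose name contains keyword."""
    return [(node.name, node.op) for node in graph_def.node if keyword in node.name]

# toy graph standing in for frozen_graph.pb
g = tf.Graph()
with g.as_default():
    x = tf1.placeholder(tf.float32, [None, 3], name="input")
    with tf1.name_scope("Logits"):
        tf.nn.softmax(x, name="Predictions")

for name, op in find_nodes(g.as_graph_def(), "Predictions"):
    print(name, op)
```

For the real model, pass the graph_def parsed from frozen_graph.pb; searching for "Predictions" or listing the Placeholder nodes quickly reveals the output and input names.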
3. inference

import cv2
import numpy as np
import tensorflow as tf

def preprocess_inception(image_np, central_fraction=0.875):
    """Central crop, then scale pixel values from [0, 255] to [-1, 1]."""
    image_height, image_width, _ = image_np.shape
    if central_fraction:
        # keep the central `central_fraction` of the image in each dimension
        bbox_start_h = int(image_height * (1 - central_fraction) / 2)
        bbox_end_h = int(image_height - bbox_start_h)
        bbox_start_w = int(image_width * (1 - central_fraction) / 2)
        bbox_end_w = int(image_width - bbox_start_w)
        image_np = image_np[bbox_start_h:bbox_end_h, bbox_start_w:bbox_end_w]
    # normalize
    image_np = 2 * (image_np / 255.) - 1
    return image_np

image_np = cv2.imread("test.jpg")
# OpenCV loads images in BGR order, while the slim models expect RGB
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
# preprocess image as inception resnet v2 does
image_np = preprocess_inception(image_np)
# resize to model input image size
image_np = cv2.resize(image_np, (299, 299))
# add a batch dimension, [1, 299, 299, 3], as float32 to match the placeholder
image_np = np.expand_dims(image_np, 0).astype(np.float32)

# load the frozen model into its own graph
with tf.gfile.GFile('/your/path/to/frozen_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    # "input" must be the name given to the placeholder when the graph was exported
    input_tensor = graph.get_tensor_by_name("input:0")
    output_tensor = graph.get_tensor_by_name("InceptionResnetV2/Logits/Predictions:0")
    # the output node is a softmax, so these are class probabilities of shape [1, num_classes]
    predictions = sess.run(output_tensor, feed_dict={input_tensor: image_np})
    print("Prediction label index:", np.argmax(predictions[0]))
    print("Top 3 prediction label indices:", np.argsort(predictions[0])[::-1][:3])
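To make the crop arithmetic concrete: with the default central_fraction=0.875, a 320x480 image loses 320 * 0.125 / 2 = 20 rows at the top and bottom and 480 * 0.125 / 2 = 30 columns on each side, leaving 280x420 pixels, and the scaling 2 * (x / 255) - 1 maps 0 to -1 and 255 to 1. The sketch below re-implements just that NumPy part (no OpenCV) on a synthetic image to check those numbers; central_crop_and_scale is a hypothetical name.

```python
import numpy as np

def central_crop_and_scale(image, central_fraction=0.875):
    """Central crop followed by [0, 255] -> [-1, 1] scaling, mirroring preprocess_inception."""
    height, width = image.shape[:2]
    start_h = int(height * (1 - central_fraction) / 2)
    start_w = int(width * (1 - central_fraction) / 2)
    cropped = image[start_h:height - start_h, start_w:width - start_w]
    return 2 * (cropped / 255.) - 1

# synthetic 320x480 RGB image: top half black (0), bottom half white (255)
image = np.zeros((320, 480, 3), dtype=np.uint8)
image[160:, :, :] = 255
out = central_crop_and_scale(image)
print(out.shape, out.min(), out.max())  # (280, 420, 3) -1.0 1.0
```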

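Because the output node is a softmax, the top-k classes come from sorting the probability vector in descending order. A minimal NumPy sketch; top_k and the five-class probability vector are hypothetical. Keep in mind that the ImageNet checkpoints of the slim Inception models have 1001 outputs with index 0 reserved for a background class, so their indices are shifted by one relative to the usual 1000 ImageNet labels.

```python
import numpy as np

def top_k(probs, k=3):
    """Return (indices, probabilities) of the k most probable classes, most probable first."""
    probs = np.asarray(probs).ravel()      # accepts shape [num_classes] or [1, num_classes]
    indices = np.argsort(probs)[::-1][:k]  # argsort sorts ascending, so reverse it
    return indices, probs[indices]

# toy softmax output for one image and five classes
predictions = np.array([[0.05, 0.10, 0.60, 0.05, 0.20]])
indices, scores = top_k(predictions, k=3)
print(indices)  # [2 4 1]
print(scores)   # [0.6 0.2 0.1]
```

With around a thousand classes a full argsort is cheap; np.argpartition would only matter for much larger label sets.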

