Tensorflow物體檢測(Object Detection)API的使用


Tensorflow在更新1.2版本之后多了很多新功能,其中放出了很多用tf框架寫的深度網絡結構(看這里),大大降低了吾等調包俠的開發難度,無論是fine-tuning還是該網絡結構都方便了不少。這里講的的是物體檢測(object detection)API,這個庫的說明文檔很詳細,可以的話直接看原文即可。

這個物體檢測API提供了5種網絡結構的預訓練的weights,全部是用COCO數據集進行訓練,可以在這里下載:分別是SSD+mobilenet, SSD+inception_v2, R-FCN+resnet101, faster RCNN+resnet101, faster RCNN+inception+resnet101。各個模型的精度和計算所需時間如下,具體測評細節可以看這篇文章

依賴包

Protobuf 2.6
Pillow 1.0
lxml
tf Slim 
Jupyter notebook
Matplotlib  # 用這個畫圖會比較慢,內存占用高,可以用cv2來代替
Tensorflow

API安裝

$ pip install tensorflow-gpu
$ sudo apt-get install protobuf-compiler python-pil python-lxml
$ sudo pip install jupyter
$ sudo pip install matplotlib

因為使用protobuf來配置模型和訓練參數,所以API正常使用必須先編譯protobuf庫

$ cd tensorflow/models
$ protoc object_detection/protos/*.proto --python_out=.

然后將models和slim(tf高級框架)加入python環境變量:

export PYTHONPATH=$PYTHONPATH:/your/path/to/tensorflow/models:/your/path/to/tensorflow/models/slim

最后測試安裝:

python object_detection/builders/model_builder_test.py

fine-tuning

  1. 准備數據集
    以Pascal VOC數據集的格式為例:object_detection/create_pascal_tf_record.py提供了一個模板,將voc格式的數據保存到.record格式
python object_detection/create_pascal_tf_record.py \
    --label_map_path=object_detection/data/pascal_label_map.pbtxt \   # 訓練物品的品類和id
    --data_dir=VOCdevkit --year=VOC2012 --set=train \
    --output_path=pascal_train.record
python object_detection/create_pascal_tf_record.py \
    --label_map_path=object_detection/data/pascal_label_map.pbtxt \
    --data_dir=VOCdevkit --year=VOC2012 --set=val \
    --output_path=pascal_val.record

其中--data_dir為訓練集的目錄。結構同Pascal VOC,如下:

    + VOCdevkit  # +為文件夾
        + JPEGImages
            - 001.jpg  # - 為文件
        + Annotations
            - 001.xml
  1. 訓練
    train和eval輸入輸出數據儲存結構為:
    + input
        - label_map.pbtxt file  # 可以在object_detection/data/*.pbtxt找到樣例
        - train TFRecord file
        - eval TFRecord file
    + models
        + modelA
            - pipeline config file # 可以在object_detection/samples/configs/*.config下找到樣例,定義訓練參數和輸入數據
            + train  # 保存訓練產生的checkpoint文件
            + eval

准備好上述文件后就可以直接調用train文件進行訓練

python object_detection/train.py \
    --logtostderr \
    --pipeline_config_path=/your/path/to/models/modelA/pipeline config file \ 
    --train_dir=/your/path/to/models/modelA/train
  1. 評估
    在訓練開始以后,就可以運行eval來評估模型的效果。不過實際情況是eval模型也需要加載ckpt文件,因此也需要占用不小的顯存,而一般訓練的時候都會調整batch盡量利用顯卡性能,所以想要實時運行train和eval的話需要調整好兩者所需的內存。
python object_detection/eval.py \
    --logtostderr \
    --pipeline_config_path=/your/path/to/models/modelA/pipeline config file \
    --checkpoint_dir=/your/path/to/models/modelA/train \
    --eval_dir=/your/path/to/models/modelA/eval
  1. 監控
    通過tensorboard命令可以在瀏覽器很輕松的監控訓練進程,在瀏覽器輸入localhost:6006(默認)即可
tensorboard --logdir=/your/path/to/models/modelA  # 需要包含eval和train目錄(.ckpt, .index, .meta, checkpoint, graph.pbtxt文件)

freeze model

在訓練完成后需要將訓練產生的最后一組.meta, .index, .ckpt, checkpoint文件。其中meta保存了graph和metadata,ckpt保存了網絡的weights。而在生產環境中進行預測的時候是只需要模型和權重,不需要metadata,所以需要將其提出進行freeze操作,將所需的部分放到一個文件,方便之后的調用,也減少模型加載所需的內存。(在下載的預訓練模型解壓后可以找到4個文件,其中名為frozen_inference_graph.pb的文件就是freeze后產生的模型文件,比weights文件大,但是比weights和meta文件加起來要小不少。)

本來,tensorflow/python/tools/freeze_graph.py提供了freeze model的api,但是需要提供輸出的final node names(一般是softmax之類的最后一層的激活函數命名),而object detection api提供提供了預訓練好的網絡,final node name並不好找,所以object_detection目錄下還提供了export_inference_graph.py

python export_inference_graph.py \
        --input_type image_tensor \
        --pipeline_config_path /your/path/to/models/modelA/pipeline config file \
        --checkpoint_path  /your/path/to/models/modelA/train/model.ckpt-* \
        --inference_graph_path /your/path/to/models/modelA/train/frozen_inference_graph.pb  # 輸出的文件名

模型調用

目錄下提供了一個樣例。這里只是稍作調整用cv2來顯示圖像。

import numpy as np
import os, sys
import tensorflow as tf
import cv2

MODEL_ROOT = "/home/arkenstone/tensorflow/workspace/models"
sys.path.append(MODEL_ROOT)  # 應用和訓練的目錄在不同的地方

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

MODEL_PATH = "/home/arkenstone/tensorflow/workspace/models/objectdetection/models/faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017"
PATH_TO_CKPT = MODEL_PATH + '/frozen_inference_graph.pb'  # frozen model path
PATH_TO_LABELS = os.path.join(MODEL_ROOT, 'object_detection/data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)  # 格式為{1:{'id': 1, 'name': 'person'}, 2: {'id': 2, 'name': 'bicycle'}, ...}

# 模型加載:test.py
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

# 防止內存不足,限制sess內存使用比例
gpu_memory_fraction = 0.4
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
config = tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False, allow_soft_placement=True)
config.gpu_options.allow_growth = False

def detect(image_path):
    with detection_graph.as_default():  # 需要手動close sess
      with tf.Session(graph=detection_graph, config=config) as sess:
          image = cv2.imread(image_path)
          image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
          image_np_expanded = np.expand_dims(image_np, axis=0)
          image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
          boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
          scores = detection_graph.get_tensor_by_name('detection_scores:0')
          classes = detection_graph.get_tensor_by_name('detection_classes:0')
          num_detections = detection_graph.get_tensor_by_name('num_detections:0')
          (boxes, scores, classes, num_detections) = sess.run(
              [boxes, scores, classes, num_detections],
              feed_dict={image_tensor: image_np_expanded})
          vis_util.visualize_boxes_and_labels_on_image_array(
              image_np,
              np.squeeze(boxes),
              np.squeeze(classes).astype(np.int32),
              np.squeeze(scores),
              category_index,
              use_normalized_coordinates=True,
              line_thickness=4)
          new_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
          cv2.imshow("test", new_img)
          cv2.waitKey(0)

if __name__ == '__main__':
    detect(/your/test/image)

參考

https://github.com/tensorflow/models/tree/master/object_detection
https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
https://www.tensorflow.org/extend/tool_developers/


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM