Training on Your Own Data with TensorFlow


Training your own dataset (using bottle as an example):

 

1. Prepare the data

Folder structure:

models
├── images
├── annotations
│   ├── xmls
│   └── trainval.txt
└── bottle
    ├── train_logs   (training output directory)
    └── val_logs     (evaluation log directory)

 

1) Download an official pre-trained model from https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md.
Using ssd_mobilenet_v1_coco as the example, copy the three model.ckpt* files from the archive into the bottle folder.

2) Prepare the JPEG images and put them in the images folder. File names must follow the pattern "name_index.jpg": the underscore is required and the index starts at 1.
Annotate the images with https://github.com/tzutalin/labelImg and place the generated XML files in the xmls folder, keeping each XML file name identical to its JPEG's name.
3) Create annotations/trainval.txt (see the folder structure above), with one line per image in the form "image_name 1 1 1"; a helper sketch for generating this file appears after this list. For example:

bottle_1 1 1 1
bottle_2 1 1 1

 

4) Create object_detection/data/bottle_label_map.pbtxt with the following content:

item {
    id: 1
    name: 'bottle'
}
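
The trainval.txt list from step 3) can be generated mechanically from the images folder. Below is a minimal sketch, assuming the directory layout above (run from the models directory) and the "name_index.jpg" naming convention; the script name and variable names are only illustrative:

# make_trainval.py (hypothetical helper): write annotations/trainval.txt and
# warn about images that have no matching labelImg XML.
import os

image_dir = 'images'
xml_dir = os.path.join('annotations', 'xmls')

with open(os.path.join('annotations', 'trainval.txt'), 'w') as f:
    for name in sorted(os.listdir(image_dir)):
        if not name.endswith('.jpg'):
            continue
        stem = os.path.splitext(name)[0]                  # e.g. bottle_1
        if not os.path.exists(os.path.join(xml_dir, stem + '.xml')):
            print('missing annotation for', name)         # XML must share the JPEG's name
            continue
        f.write('{} 1 1 1\n'.format(stem))                # same "name 1 1 1" format as above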
 

2. Generate the data

# From tensorflow/models
python object_detection/create_pet_tf_record.py \
--label_map_path=object_detection/data/bottle_label_map.pbtxt \
--data_dir=`pwd` \
--output_dir=`pwd`

 

This produces pet_train.record and pet_val.record; move both files into the bottle folder.
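
Optionally, you can sanity-check the generated records before training. A minimal sketch, assuming a TF 1.x install and that both files have been moved into bottle/ as above:

# Count the examples in each generated TFRecord file.
import tensorflow as tf

for path in ['bottle/pet_train.record', 'bottle/pet_val.record']:
    count = sum(1 for _ in tf.python_io.tf_record_iterator(path))
    print('{}: {} examples'.format(path, count))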

 

3. Prepare the config file

Copy object_detection/samples/configs/ssd_mobilenet_v1_pets.config to bottle/ssd_mobilenet_v1_bottle.config and make the following changes to ssd_mobilenet_v1_bottle.config:

Line 9: num_classes: 1 (this value equals the number of item entries in bottle_label_map.pbtxt)
Line 158: fine_tune_checkpoint: "bottle/model.ckpt"
Line 177: input_path: "bottle/pet_train.record"
Lines 179 and 193: label_map_path: "object_detection/data/bottle_label_map.pbtxt"
Line 191: input_path: "bottle/pet_val.record"
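
For reference, after these edits the relevant stanzas of ssd_mobilenet_v1_bottle.config should look roughly like the abridged excerpt below (exact line numbers vary between releases of the sample config):

model {
  ssd {
    num_classes: 1   # one item entry in bottle_label_map.pbtxt
    ...
  }
}

train_config: {
  ...
  fine_tune_checkpoint: "bottle/model.ckpt"
  ...
}

train_input_reader: {
  tf_record_input_reader {
    input_path: "bottle/pet_train.record"
  }
  label_map_path: "object_detection/data/bottle_label_map.pbtxt"
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "bottle/pet_val.record"
  }
  label_map_path: "object_detection/data/bottle_label_map.pbtxt"
  ...
}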


4. Train

# From tensorflow/models
python object_detection/train.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--train_dir=bottle/train_logs \
2>&1 | tee bottle/train_logs.txt &


5. Evaluate

# From tensorflow/models
python object_detection/eval.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--checkpoint_dir=bottle/train_logs \
--eval_dir=bottle/val_logs &


6. Visualize the logs

You can visualize the training logs while training is still running and watch the loss trend:

# From tensorflow/models
tensorboard --logdir=bottle/train_logs

 

Then open ip:6006 in a browser to see the loss curves as well as prediction results on individual images.

 

7. Export the model

# From tensorflow/models
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path bottle/ssd_mobilenet_v1_bottle.config \
--trained_checkpoint_prefix bottle/train_logs/model.ckpt-8 \
--output_directory bottle
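
In the command above, model.ckpt-8 is simply whichever checkpoint the training run had reached in bottle/train_logs. A minimal sketch (assuming TF 1.x) for locating the newest checkpoint prefix to pass to --trained_checkpoint_prefix:

# Print the newest checkpoint prefix written by train.py, e.g. bottle/train_logs/model.ckpt-8
import tensorflow as tf

print(tf.train.latest_checkpoint('bottle/train_logs'))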

 

This generates the bottle/frozen_inference_graph.pb file.

 

8. Test on images

Run object_detection_tutorial.ipynb with its paths updated accordingly,
or write your own inference script, e.g. tensorflow/models/object_detection/infer.py:

import sys
sys.path.append('..')
import time

import numpy as np
import tensorflow as tf
from PIL import Image
from matplotlib import pyplot as plt

from utils import label_map_util
from utils import visualization_utils as vis_util

PATH_TEST_IMAGE = sys.argv[1]
# Point these at your exported model and label map
# (for the bottle example: 'bottle/frozen_inference_graph.pb',
#  'object_detection/data/bottle_label_map.pbtxt' and NUM_CLASSES = 1).
PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
NUM_CLASSES = 21
IMAGE_SIZE = (18, 12)

# Build the category index used to map class ids to display names.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Load the frozen inference graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with detection_graph.as_default():
    with tf.Session(graph=detection_graph, config=config) as sess:
        start_time = time.time()
        print(time.ctime())

        # Load the test image and add a batch dimension.
        image = Image.open(PATH_TEST_IMAGE)
        image_np = np.array(image).astype(np.uint8)
        image_np_expanded = np.expand_dims(image_np, axis=0)

        # Input and output tensors of the detection graph.
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')

        # Run detection on the image.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))

        # Draw the detections on the image and display it.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32),
            np.squeeze(scores), category_index,
            use_normalized_coordinates=True, line_thickness=8)
        plt.figure(figsize=IMAGE_SIZE)
        plt.imshow(image_np)
        plt.show()

 

Run it with: python infer.py test_images/image1.jpg

