[Demo 1] Pedestrian Detection with the object_detection API, Part 3: Model Training and Testing


Training preparation

Model selection

Choose the ssd_mobilenet_v2_coco model, download it from the detection model zoo (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md), and extract it to ./Pedestrian_Detection/ssd_mobilenet_v2_coco_2018_03_29.

Modify the object_detection config file

Go to ./Pedestrian_Detection/models/research/object_detection/samples/configs, find the config file that matches the model, ssd_mobilenet_v2_coco.config, and edit it.

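Following the hints in the file:

1. Line 9: change the number of detection classes (num_classes) from 90 to 1. We only detect pedestrians, so there is a single class.

2. Update all of the remaining path entries (five in total):

  2.1. The first (line 156) is the path to the pretrained model downloaded in the previous step: ./Pedestrian_Detection/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt

  2.2. The second (line 175) is the path to the train.record file, the record file prepared in the previous article: ./Pedestrian_Detection/project/pedestrian_train/data/pascal_train.record

  2.3. The third (line 177) is the path to the label_map.pbtxt prepared in the previous article: ./Pedestrian_Detection/project/pedestrian_train/data/label_map.pbtxt

  2.4. The fourth (line 189) is the path to the eval.record file, the record file prepared in the previous article: ./Pedestrian_Detection/project/pedestrian_train/data/pascal_eval.record

  2.5. The fifth (line 191) is the same as 2.3.

That completes the config file changes. Next, place the file in ./Pedestrian_Detection/project/pedestrian_train/models. Finally, create two folders in that directory, train and eval, to store the training and evaluation records.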
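
The edits above can also be scripted rather than made by hand. The following is only a minimal sketch: the function name patch_pipeline_config and the abbreviated SAMPLE_CONFIG text are illustrative stand-ins, not part of the object_detection API. It assumes the config holds exactly one fine_tune_checkpoint entry, two input_path entries (train first, then eval), and two label_map_path entries, which matches the five paths listed above.

```python
import re


def patch_pipeline_config(text, num_classes, checkpoint,
                          train_record, eval_record, label_map):
    """Return the config text with num_classes and the five paths replaced."""

    def set_field(name, values, source):
        # Replace each `name: "..."` occurrence, in order, with the next value.
        # A function is used as the replacement so that backslashes in
        # Windows paths are inserted literally.
        pending = iter(values)
        pattern = r'\b' + name + r':\s*"[^"]*"'
        return re.sub(pattern,
                      lambda m: '%s: "%s"' % (name, next(pending)),
                      source)

    text = re.sub(r'\bnum_classes:\s*\d+',
                  'num_classes: %d' % num_classes, text)
    text = set_field('fine_tune_checkpoint', [checkpoint], text)
    text = set_field('input_path', [train_record, eval_record], text)
    text = set_field('label_map_path', [label_map, label_map], text)
    return text


# Abbreviated stand-in for ssd_mobilenet_v2_coco.config (illustration only).
SAMPLE_CONFIG = '''
model { ssd { num_classes: 90 } }
train_config {
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
  fine_tune_checkpoint_type: "detection"
}
train_input_reader {
  tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/train.record" }
  label_map_path: "PATH_TO_BE_CONFIGURED/label_map.pbtxt"
}
eval_input_reader {
  tf_record_input_reader { input_path: "PATH_TO_BE_CONFIGURED/eval.record" }
  label_map_path: "PATH_TO_BE_CONFIGURED/label_map.pbtxt"
}
'''

DATA_DIR = './Pedestrian_Detection/project/pedestrian_train/data'
patched = patch_pipeline_config(
    SAMPLE_CONFIG,
    num_classes=1,
    checkpoint='./Pedestrian_Detection/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt',
    train_record=DATA_DIR + '/pascal_train.record',
    eval_record=DATA_DIR + '/pascal_eval.record',
    label_map=DATA_DIR + '/label_map.pbtxt',
)
print(patched)
```

To apply it to the real file, read ssd_mobilenet_v2_coco.config into a string, pass it through the function, and write the result to the models directory. It is worth checking the output against the line numbers listed above before training.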

Start training

Open a command-line window.

In the research directory, enter:

(dl) D:\Study\dl\Pedestrian_Detection\models\research>python object_detection/legacy/train.py --train_dir=D:\Study\dl\Pedestrian_Detection\project\pedestrian_train\models\train --pipeline_config_path=D:\Study\dl\Pedestrian_Detection\project\pedestrian_train\models\ssd_mobilenet_v2_coco.config --logtostderr

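Training then starts.

Here we let training run past 2000 steps and then pressed Ctrl+C to stop it. The training details can be inspected with TensorBoard (not covered here).

Our training records:

Export the model

Here we use the checkpoint from step 2391 to generate the model.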

Copy the four checkpoint files shown in the figure below into the ./Pedestrian_Detection/pedestrian_data/model directory.

 

 

 

 

In the command-line window, enter:

(dl) D:\Study\dl\Pedestrian_Detection\models\research>python object_detection/export_inference_graph.py --input_type=image_tensor --pipeline_config_path=D:\Study\dl\Pedestrian_Detection\project\pedestrian_train\models\ssd_mobilenet_v2_coco.config --trained_checkpoint_prefix=D:\Study\dl\Pedestrian_Detection\pedestrian_data\model\model.ckpt-2391 --output_directory=D:\Study\dl\Pedestrian_Detection\pedestrian_data\test

The output directory now contains a set of generated model files:

 

Test the model

Test code:

import os
import sys

import cv2
import numpy as np
import tensorflow as tf

# Add the parent directory to the import path so that the
# object_detection package can be found.
sys.path.append("..")
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

# Path to the frozen detection graph exported in the previous step.
PATH_TO_CKPT = 'D:/Study/dl/Pedestrian_Detection/pedestrian_data/test/frozen_inference_graph.pb'

# Label map used to turn class ids into display names for each box.
PATH_TO_LABELS = os.path.join('D:/Study/dl/Pedestrian_Detection/project/pedestrian_train/data', 'label_map.pbtxt')

NUM_CLASSES = 1

# Load the frozen graph into its own tf.Graph (TensorFlow 1.x API).
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        image_path = "D:/Study/dl/Pedestrian_Detection/project/test_images/3600.jpg"
        image_np = cv2.imread(image_path)
        # cv2.imread returns None instead of raising when the file cannot be read.
        if image_np is None:
            raise IOError("Could not read image: " + image_path)
        cv2.imshow("input", image_np)
        print(image_np.shape)
        # OpenCV loads images as BGR, while the model was trained on RGB input,
        # so convert the channel order before running inference.
        image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        # Add a batch dimension: (height, width, 3) -> (1, height, width, 3).
        image_np_expanded = np.expand_dims(image_rgb, axis=0)
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Draw the detections onto the original BGR image so that
        # cv2.imshow and cv2.imwrite show the photo with correct colors.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            min_score_thresh=0.5,
            line_thickness=1)
        cv2.imshow('object detection', image_np)
        cv2.imwrite("D:/run_result.png", image_np)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
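
If the raw detections are needed instead of a drawn image, the boxes can be converted to pixel coordinates by hand. The sketch below assumes the usual output layout of the exported graph, where each entry of detection_boxes is a normalized [ymin, xmin, ymax, xmax] quadruple (consistent with use_normalized_coordinates=True in the script above). The helper name boxes_to_pixels is illustrative and not part of the object_detection API.

```python
def boxes_to_pixels(boxes, scores, image_height, image_width, min_score=0.5):
    """Keep detections scoring at least min_score and scale them to pixels.

    boxes:  iterable of normalized [ymin, xmin, ymax, xmax] values in [0, 1].
    scores: iterable of confidence scores, aligned with boxes.
    Returns a list of (left, top, right, bottom, score) tuples.
    """
    results = []
    for box, score in zip(boxes, scores):
        if score < min_score:
            continue
        ymin, xmin, ymax, xmax = box
        results.append((
            int(round(xmin * image_width)),   # left
            int(round(ymin * image_height)),  # top
            int(round(xmax * image_width)),   # right
            int(round(ymax * image_height)),  # bottom
            float(score),
        ))
    return results
```

In the test script it could be called after sess.run as boxes_to_pixels(np.squeeze(boxes), np.squeeze(scores), image_np.shape[0], image_np.shape[1]). The default threshold of 0.5 mirrors min_score_thresh in the visualization call.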

Test result:

 

