Detectron2訓練visdrone記錄


准備

VOC標簽轉換參見這篇
注意:object_name = name_dict[box[4]] 改為 object_name = name_dict[box[5]]。為了與detectron2統一,
標簽文件夾命名為Annotations,圖片文件夾命名為JPEGImages,train.txt位於xxx/ImageSets/Main/。

train

構建instance

# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import numpy as np
import os
import xml.etree.ElementTree as ET
from fvcore.common.file_io import PathManager

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode

__all__ = ["register_visdrone_voc"]

# Category names for the VisDrone VOC-style annotations.
# load_voc_instances() uses a name's position in this list as its "category_id",
# and register_visdrone_voc() passes the list as the dataset's `thing_classes`.
CLASS_NAMES = ['__background__',  # always index 0
               'pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']


def load_voc_instances(dirname: str, split: str):
    """
    Load Pascal VOC style detection annotations into Detectron2 dataset dicts.

    Args:
        dirname: Dataset root that contains "Annotations", "ImageSets", "JPEGImages".
        split (str): Name of the split file (without ".txt") found under
            "ImageSets/Main", e.g. one of "train", "test", "val", "trainval".

    Returns:
        list[dict]: One dict per image id listed in the split file, with the keys
        "file_name", "image_id", "height", "width" and "annotations". Each
        annotation holds "category_id" (index of the object name in CLASS_NAMES),
        "bbox" ([xmin, ymin, xmax, ymax] as floats) and "bbox_mode"
        (BoxMode.XYXY_ABS).

    Raises:
        ValueError: If an object's <name> is not present in CLASS_NAMES.
    """
    with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        # The `np.str` alias no longer exists in current NumPy releases; the
        # builtin `str` is the equivalent dtype. `ndmin=1` keeps the result a
        # 1-D array even when the split file lists a single image id, so the
        # loop below can always iterate over it.
        fileids = np.loadtxt(f, dtype=str, ndmin=1)

    # Needs to read many small annotation files. Makes sense at local
    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
    dicts = []
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []

        for obj in tree.findall("object"):
            cls = obj.find("name").text
            # "difficult" samples are deliberately kept for training: the
            # <difficult> tag is not read and no object is skipped because of it.
            bbox = obj.find("bndbox")
            bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
            # Original annotations are integers in the range [1, W or H]
            # Assuming they mean 1-based pixel indices (inclusive),
            # a box with annotation (xmin=1, xmax=W) covers the whole image.
            # In coordinate space this is represented by (xmin=0, xmax=W)
            bbox[0] -= 1.0
            bbox[1] -= 1.0
            instances.append(
                {"category_id": CLASS_NAMES.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
            )
        r["annotations"] = instances
        dicts.append(r)
    return dicts


def register_visdrone_voc(name, dirname, split, year):
    """
    Register one VisDrone (VOC layout) split with Detectron2 under `name`.

    Args:
        name: Key to register in DatasetCatalog / MetadataCatalog.
        dirname: Dataset root directory handed to load_voc_instances.
        split: Split name handed to load_voc_instances.
        year: Stored as-is in the dataset metadata.
    """

    def _load_split():
        # Called by the catalog when the dataset is requested.
        return load_voc_instances(dirname, split)

    DatasetCatalog.register(name, _load_split)

    metadata = MetadataCatalog.get(name)
    metadata.set(thing_classes=CLASS_NAMES, dirname=dirname, year=year, split=split)

train采用Faster R-CNN with FPN,backbone使用ResNeXt-101,群卷積32x8d,即32個group,每個group 8個filter。
注意:如果因圖片損壞無法訓練,可修改PIL庫的ImageFile.py,將LOAD_TRUNCATED_IMAGES 改為 True;也可以不改動庫文件,在訓練腳本開頭執行 from PIL import ImageFile; ImageFile.LOAD_TRUNCATED_IMAGES = True。

from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultPredictor
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader
from visdrone_voc import *
import os
import cv2
import torch

# Register the train / val / test splits of VisDrone (converted to VOC layout).
for dataset_name, dataset_dir, dataset_split in (
    ('VISDRONE_VOC', '/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-train', 'train'),
    ('VISDRONE_VAL', '/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val', 'val'),
    ('VISDRONE_TEST', '/home/chenzhengxi/data/VisDrone/VisDrone2019-DET-test-dev', 'test'),
):
    register_visdrone_voc(dataset_name, dataset_dir, dataset_split, 2012)

# Start from Detectron2's defaults and overlay the project config file.
cfg = get_cfg()
cfg.merge_from_file('configs/faster_rcnn_X_101_32x8d_FPN_3x.yaml')

# cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 4
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128

# Make sure the output directory exists before the trainer is created.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
# Pass resume=True to continue training from the most recent weights instead.
trainer.resume_or_load(resume=False)
trainer.train()

# The commented-out lines below show how to load one specific weights file.
#cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_0229999.pth")
#checkpointer = DetectionCheckpointer(trainer.model)
#checkpointer.load(cfg.MODEL.WEIGHTS)

# Set the testing threshold for this model.
# NOTE(review): trainer.model was already built from cfg above -- confirm that
# changing the threshold here still reaches the model used for evaluation.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

# Evaluate on the first dataset listed in cfg.DATASETS.TEST.
test_dataset_name = cfg.DATASETS.TEST[0]
evaluator = PascalVOCDetectionEvaluator(test_dataset_name)
# val_loader = build_detection_test_loader(cfg, "VISDRONE_VAL")
# result_val = inference_on_dataset(trainer.model, val_loader, evaluator)
# print(result_val)
test_results = trainer.test(cfg, trainer.model, evaluator)
print(test_results)

# predictor = DefaultPredictor(cfg)
# im = cv2.imread('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val/JPEGImages/0000026_03500_d_0000031.jpg')
# outputs = predictor(im)
# ooo = outputs['instances'].to(torch.device("cpu"))
# boxes = ooo.pred_boxes.tensor.numpy()
# print(boxes)
# for i in range(len(boxes)):
#     cv2.rectangle(im, tuple(boxes[i, 0:2]), tuple(boxes[i, 2:4]), (0, 255, 0), 2)
#
# cv2.imshow('visdrone', im)
# cv2.waitKey(0)

測試記錄

注意:數據集沒有__background__,計算AP時會除0,需修改pascal_voc_evaluation.py Line 90:

            for cls_id, cls_name in enumerate(self._class_names):
+++             if cls_id == 0:  # __background__
+++                 continue
                lines = predictions.get(cls_id, [""])
AP AP50 AP75 iter datasets
20.8397 39.0423 19.9213 94999 val
17.0480 32.8580 16.1755 94999 test
22.2992 40.0259 21.7078 169999 val
18.0168 33.5292 17.6131 169999 test
22.9258 41.0643 22.4904 214999 val
18.1249 33.7228 17.6461 214999 test
22.8556 40.9861 22.3667 269999 val
18.0256 33.5866 17.5259 269999 test

可以看出效果遠高於yolo,最終配置和權重下載,提取碼: 74s4


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM