mmdetection訓練自己的模型【數據集轉變,數據集划分,數據集gt可視化,mmdetection配置文件生成及修改,開始訓練,gradio部署】


針對有一點mmdetction基礎的,然后想根據自己的數據集,熟練訓練自己的模型。需要改成自己配置的地方,我會在代碼中做好標記,方便修改。

我們先了解一下mmdetection的基本流程,你想訓練一個模型,你只需要准備的是:數據集mmdetection的配置文件

下面我分為兩部分,分別處理這兩個東西。然后你就可以用官方實現的訓練工具愉快的進行訓練了。

1. 數據集的處理

先把數據集復制到mmdetection的data目錄下,方便管理,data目錄下一個文件夾就是一個數據集。dataset1/data/目錄下是你的.xml文件和.jpg文件,如果你的數據集本身就是voc數據集,那可以跳過步驟1.1。

  • xml2voc2007.py:用於將.xml文件轉換成voc2007數據集。
  • voc2coco.py:用於將voc數據集轉換成coco數據集。
  • box_visiual.py:利用coco數據集可視化數據集的ground truth。查看數據集中是否有臟數據,根據具體情況除掉。

如果需要用到其他的格式轉換,或者數據集處理的一些操作,參考:數據集拆分,互轉,可視化,查錯 - 一屆書生 - 博客園 (cnblogs.com)

1.1 數據集轉變: .xml --> voc數據集

首先是數據集的處理,我是比較習慣用coco數據集,雖然mmdetection也可以訓練voc數據集。因為我拿到手的是一個.jpg和.xml文件的數據集,因為我們要先將.xml文件數據集轉換成voc數據集,然后再將voc數據集轉換成coco數據集。

mmdetection/data/dataset1/xml2voc2007.py

# 命令行執行:  python xml2voc2007.py --input_dir data --output_dir VOCdevkit
import argparse
import glob
import os
import os.path as osp
import random
import shutil
import sys

percent_train = 0.9  # 改成你想設置的訓練集比例。


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--input_dir", default="data",
                        help="input annotated directory")  # 將保存你.jpg和.xml文件的文件夾名改為data,下邊就不用動了
    parser.add_argument("--output_dir", default="VOCdevkit", help="output dataset directory")  # 輸出的voc數據集目錄,不用動
    args = parser.parse_args()

    if osp.exists(args.output_dir):
        print("Output directory already exists:", args.output_dir)
        sys.exit(1)
    os.makedirs(args.output_dir)
    print("| Creating dataset dir:", osp.join(args.output_dir, "VOC2007"))

    # 創建保存的文件夾
    if not os.path.exists(osp.join(args.output_dir, "VOC2007", "Annotations")):
        os.makedirs(osp.join(args.output_dir, "VOC2007", "Annotations"))
    if not os.path.exists(osp.join(args.output_dir, "VOC2007", "ImageSets")):
        os.makedirs(osp.join(args.output_dir, "VOC2007", "ImageSets"))
    if not os.path.exists(osp.join(args.output_dir, "VOC2007", "ImageSets", "Main")):
        os.makedirs(osp.join(args.output_dir, "VOC2007", "ImageSets", "Main"))
    if not os.path.exists(osp.join(args.output_dir, "VOC2007", "JPEGImages")):
        os.makedirs(osp.join(args.output_dir, "VOC2007", "JPEGImages"))

    # 獲取目錄下所有的.jpg文件列表
    total_img = glob.glob(osp.join(args.input_dir, "*.jpg"))
    print('| Image number: ', len(total_img))

    # 獲取目錄下所有的joson文件列表
    total_xml = glob.glob(osp.join(args.input_dir, "*.xml"))
    print('| Xml number: ', len(total_xml))

    num_total = len(total_xml)
    data_list = range(num_total)

    num_tr = int(num_total * percent_train)
    num_train = random.sample(data_list, num_tr)

    print('| Train number: ', num_tr)
    print('| Val number: ', num_total - num_tr)

    file_train = open(
        osp.join(args.output_dir, "VOC2007", "ImageSets", "Main", "train.txt"), 'w')
    file_val = open(
        osp.join(args.output_dir, "VOC2007", "ImageSets", "Main", "val.txt"), 'w')

    for i in data_list:
        name = total_xml[i][:-4] + '\n'  # 去掉后綴'.jpg' 
        if i in num_train:
            file_train.write(name[5:])  # 因為這里的name是帶着目錄的,也就是name本來是:'data/1.jpg' ,去掉'data/' ,就是文件名了。
        else:
            file_val.write(name[5:])

    file_train.close()
    file_val.close()

    if os.path.exists(args.input_dir):
        # root 所指的是當前正在遍歷的這個文件夾的本身的地址
        # dirs 是一個 list,內容是該文件夾中所有的目錄的名字(不包括子目錄)
        # files 同樣是 list, 內容是該文件夾中所有的文件(不包括子目錄)
        for root, dirs, files in os.walk(args.input_dir):
            for file in files:
                src_file = osp.join(root, file)
                if src_file.endswith(".jpg"):
                    shutil.copy(src_file, osp.join(args.output_dir, "VOC2007", "JPEGImages"))
                else:
                    shutil.copy(src_file, osp.join(args.output_dir, "VOC2007", "Annotations"))
    print('| Done!')

if __name__ == "__main__":
    print("—" * 50)
    main()
    print("—" * 50)

1.2 數據集轉變: voc數據集 --> coco數據集

寫的有點繁瑣了,代碼比較冗長,暫時沒有時間去優化一下。 但是很好用!!!

mmdetection/data/dataset1/voc2coco.py

# -*- coding: utf-8 -*-
import json
import os
import shutil

root_path = os.getcwd()


def voc2coco():
    import datetime
    from PIL import Image

    # 處理coco數據集中category字段。
    # 創建一個 {類名 : id} 的字典,並保存到 總標簽data 字典中。
    class_name_to_id = {'point': 1, }

    # 創建coco的文件夾
    if not os.path.exists(os.path.join(root_path, "coco2017")):
        os.makedirs(os.path.join(root_path, "coco2017"))
        os.makedirs(os.path.join(root_path, "coco2017", "annotations"))
        os.makedirs(os.path.join(root_path, "coco2017", "train2017"))
        os.makedirs(os.path.join(root_path, "coco2017", "val2017"))

    # 創建 總標簽data
    now = datetime.datetime.now()
    data = dict(
        info=dict(
            description=None,
            url=None,
            version=None,
            year=now.year,
            contributor=None,
            date_created=now.strftime("%Y-%m-%d %H:%M:%S.%f"),
        ),
        licenses=[dict(url=None, id=0, name=None, )],
        images=[
            # license, file_name,url, height, width, date_captured, id
        ],
        type="instances",
        annotations=[
            # segmentation, area, iscrowd, image_id, bbox, category_id, id
        ],
        categories=[
            # supercategory, id, name
        ],
    )

    for name, id in class_name_to_id.items():
        data["categories"].append(
            dict(supercategory=None, id=id, name=name, )
        )

    # 處理coco數據集train中images字段。
    images_dir = os.path.join(root_path, 'VOCdevkit', 'VOC2007', 'JPEGImages')
    images = os.listdir(images_dir)

    # 生成每個圖片對應的image_id
    images_id = {}
    for idx, image_name in enumerate(images):
        images_id.update({image_name[:-4]: idx})

    # 獲取訓練圖片
    train_img = []
    fp = open(os.path.join(root_path, 'VOCdevkit', 'VOC2007', 'ImageSets', 'Main', 'train.txt'))
    for i in fp.readlines():
        train_img.append(i[:-1] + ".jpg")

    # 獲取訓練圖片的數據
    for image in train_img:
        img = Image.open(os.path.join(images_dir, image))
        data["images"].append(
            dict(
                license=0,
                url=None,
                file_name=image,  # 圖片的文件名帶后綴
                height=img.height,
                width=img.width,
                date_captured=None,
                # id=image[:-4],
                id=images_id[image[:-4]],
            )
        )

    # 獲取coco數據集train中annotations字段。
    train_xml = [i[:-4] + '.xml' for i in train_img]

    bbox_id = 0
    for xml in train_xml:
        category = []
        xmin = []
        ymin = []
        xmax = []
        ymax = []
        import xml.etree.ElementTree as ET
        tree = ET.parse(os.path.join(root_path, 'VOCdevkit', 'VOC2007', 'Annotations', xml))
        root = tree.getroot()
        object = root.findall('object')
        for i in object:
            category.append(class_name_to_id[i.findall('name')[0].text])
            bndbox = i.findall('bndbox')
            for j in bndbox:
                xmin.append(float(j.findall('xmin')[0].text))
                ymin.append(float(j.findall('ymin')[0].text))
                xmax.append(float(j.findall('xmax')[0].text))
                ymax.append(float(j.findall('ymax')[0].text))
        for i in range(len(category)):
            data["annotations"].append(
                dict(
                    id=bbox_id,
                    image_id=images_id[xml[:-4]],
                    category_id=category[i],
                    area=(xmax[i] - xmin[i]) * (ymax[i] - ymin[i]),
                    bbox=[xmin[i], ymin[i], xmax[i] - xmin[i], ymax[i] - ymin[i]],
                    iscrowd=0,
                )
            )
            bbox_id += 1
    # 生成訓練集的json
    json.dump(data, open(os.path.join(root_path, 'coco2017', 'annotations', 'instances_train2017.json'), 'w'))

    # 獲取驗證圖片
    val_img = []
    fp = open(os.path.join(root_path, 'VOCdevkit', 'VOC2007', 'ImageSets', 'Main', 'val.txt'))
    for i in fp.readlines():
        val_img.append(i[:-1] + ".jpg")

    # 將訓練的images和annotations清空,
    del data['images']
    data['images'] = []
    del data['annotations']
    data['annotations'] = []

    # 獲取驗證集圖片的數據
    for image in val_img:
        img = Image.open(os.path.join(images_dir, image))
        data["images"].append(
            dict(
                license=0,
                url=None,
                file_name=image,  # 圖片的文件名帶后綴
                height=img.height,
                width=img.width,
                date_captured=None,
                id=images_id[image[:-4]],
            )
        )

    # 處理coco數據集驗證集中annotations字段。
    val_xml = [i[:-4] + '.xml' for i in val_img]

    for xml in val_xml:
        category = []
        xmin = []
        ymin = []
        xmax = []
        ymax = []
        import xml.etree.ElementTree as ET
        tree = ET.parse(os.path.join(root_path, 'VOCdevkit', 'VOC2007', 'Annotations', xml))
        root = tree.getroot()
        object = root.findall('object')
        for i in object:
            category.append(class_name_to_id[i.findall('name')[0].text])
            bndbox = i.findall('bndbox')
            for j in bndbox:
                xmin.append(float(j.findall('xmin')[0].text))
                ymin.append(float(j.findall('ymin')[0].text))
                xmax.append(float(j.findall('xmax')[0].text))
                ymax.append(float(j.findall('ymax')[0].text))
        for i in range(len(category)):
            data["annotations"].append(
                dict(
                    id=bbox_id,
                    image_id=images_id[xml[:-4]],
                    category_id=category[i],
                    area=(xmax[i] - xmin[i]) * (ymax[i] - ymin[i]),
                    bbox=[xmin[i], ymin[i], xmax[i] - xmin[i], ymax[i] - ymin[i]],
                    iscrowd=0,
                )
            )
            bbox_id += 1
    # 生成驗證集的json
    json.dump(data, open(os.path.join(root_path, 'coco2017', 'annotations', 'instances_val2017.json'), 'w'))
    print('| VOC -> COCO annotations transform finish.')
    print('Start copy images...')

    for img_name in train_img:
        shutil.copy(os.path.join(root_path, "VOCdevkit", "VOC2007", "JPEGImages", img_name),
                    os.path.join(root_path, "coco2017", 'train2017', img_name))
    print('| Train images copy finish.')

    for img_name in val_img:
        shutil.copy(os.path.join(root_path, "VOCdevkit", "VOC2007", "JPEGImages", img_name),
                    os.path.join(root_path, "coco2017", 'val2017', img_name))
    print('| Val images copy finish.')


if __name__ == '__main__':
    print("—" * 50)
    voc2coco()  # voc數據集轉換成coco數據集
    print("—" * 50)

1.3 數據集真實值可視化

利用coco數據集可視化數據集的ground truth。查看數據集中是否有臟數據,根據具體情況除掉。

mmdetection/data/dataset1/box_visiual.py

import json
import os
import random

import cv2

root_path = os.getcwd()
SAMPLE_NUMBER = 30  # 隨機挑選多少個圖片檢查,
id_category = {1: 'point'}  # 改成自己的類別


def visiual():
    # 獲取bboxes
    json_file = os.path.join(root_path, 'coco2017', 'annotations', 'instances_train2017.json')  # 如果想查看驗證集,就改這里
    data = json.load(open(json_file, 'r'))
    images = data['images']  # json中的image列表,

    # 讀取圖片
    for i in random.sample(images, SAMPLE_NUMBER):  # 隨機挑選SAMPLE_NUMBER個檢測
        # for i in images:                                        # 整個數據集檢查
        img = cv2.imread(os.path.join(root_path, 'coco2017', 'train2017',
                                      i['file_name']))  # 改成驗證集的話,這里的圖片目錄也需要改,train2017 -> val2017
        bboxes = []  # 獲取每個圖片的bboxes
        category_ids = []
        annotations = data['annotations']
        for j in annotations:
            if j['image_id'] == i['id']:
                bboxes.append(j["bbox"])
                category_ids.append(j['category_id'])

        # 生成錨框
        for idx, bbox in enumerate(bboxes):
            left_top = (int(bbox[0]), int(bbox[1]))  # 這里數據集中bbox的含義是,左上角坐標和右下角坐標。
            right_bottom = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))  # 根據不同數據集中bbox的含義,進行修改。
            cv2.rectangle(img, left_top, right_bottom, (0, 255, 0), 1)  # 圖像,左上角,右下坐標,顏色,粗細
            cv2.putText(img, id_category[category_ids[idx]], left_top, cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 0.4,
                        (255, 255, 255), 1)
            # 畫出每個bbox的類別,參數分別是:圖片,類別名(str),坐標,字體,大小,顏色,粗細
        # cv2.imshow('image', img)                                          # 展示圖片,
        # cv2.waitKey(1000)
        cv2.imwrite(os.path.join('visiual', i['file_name']), img)  # 或者是保存圖片
    # cv2.destroyAllWindows()


if __name__ == '__main__':
    print('—' * 50)
    os.mkdir('visiual')
    visiual()
    print('| visiual completed.')
    print('| saved as ', os.path.join(os.getcwd(), 'visiual'))
    print('—' * 50)

到這里我們的數據集就准備好了,第一大步完成,開始第二步。

2. 配置文件的處理

配置文件的處理,我們主要在work_dirs目錄下,如果在你 mmdetection/ 目錄下沒有 work_dirs 目錄的話,新建一個文件夾,然后我們在 work_dirs/ 目錄下新建一個自己的項目文件夾,例如圖中 dataset1。然后我們在 dataset1/ 目錄下見一個python文件,用於生成配置文件。

2.1 生成配置文件

先生成一個我們的配置文件,然后我們再在配置文件中做詳細修改。

mmdetection/work_dirs/dataset1/create_config.py

import os
import random
import numpy as np
import torch
from mmcv import Config
from mmdet.apis import set_random_seed

# from mmcv.ops import get_compiling_cuda_version, get_compiler_version
# print(torch.__version__, torch.cuda.is_available())
# print(get_compiling_cuda_version())
# print(get_compiler_version())

"""
設置隨機種子
"""
seed = 7777

"""Sets the random seeds."""
set_random_seed(seed, deterministic=False)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['PYTHONHASHSEED'] = str(seed)

job_num = '1'  # 根據我的經驗,設置一個job編號方便管理。
model_name = f'cascade_rcnn_r50_fpn_1x_job{job_num}'  # 改成自己要使用的模型名字
work_dir = os.path.join(os.getcwd(), model_name)  # 訓練過程中,保存文件的路徑,不用動。
baseline_cfg_path = "../../configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py"  # 改成自己要使用的模型的路徑
cfg_path = os.path.join(work_dir, model_name + '.py')  # 生成的配置文件保存的路徑

train_data_images = os.getcwd() + '/../../data/mchar/mchar_train'  # 改成自己訓練集圖片的目錄。
val_data_images = os.getcwd() + '/../../data/mchar/mchar_val'  # 改成自己驗證集圖片的目錄。
test_data_images = os.getcwd() + '/../../data/mchar/mchar_test'  # 改成自己測試集圖片的目錄。

# File config
num_classes = 1  # 改成自己的類別數。
classes = ("point",)  # 改成自己的類別
# 去找個網址里找你對應的模型的網址: https://github.com/open-mmlab/mmdetection/blob/master/README_zh-CN.md
load_from = 'https://download.openmmlab.com/mmdetection/v2.0/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco/cascade_rcnn_r50_fpn_1x_coco_20200316-3dc56deb.pth'

train_ann_file = os.getcwd() + '/../../data/mchar/instances_train2017.json'  # 修改為自己的數據集的訓練集json
val_ann_file = os.getcwd() + '/../../data/mchar/instances_val2017.json'  # 修改為自己的數據集的驗證集json

# Train config              # 根據自己的需求對下面進行配置
gpu_ids = [1]  # 沒啥用,后邊用官方的工具進行訓練,這里無所謂。
total_epochs = 30  # 改成自己想訓練的總epoch數
batch_size = 2 ** 2  # 根據自己的顯存,改成合適數值,建議是2的倍數。
num_worker = 2  # 比batch_size小,就行
log_interval = 100  # 日志打印的間隔
checkpoint_interval = 8  # 權重文件保存的間隔
evaluation_interval = 1  # 驗證的間隔,這個一般不用動
lr = 0.01 / 2  # 學習率

"""
制作mmdetection的cascade配置文件
"""


def create_mm_config():
    cfg = Config.fromfile(baseline_cfg_path)

    cfg.work_dir = work_dir

    # Set seed thus the results are more reproducible
    cfg.seed = seed

    # You should change this if you use different model
    cfg.load_from = load_from

    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    print("| work dir:", work_dir)

    # Set the number of classes
    for head in cfg.model.roi_head.bbox_head:
        head.num_classes = num_classes

    cfg.gpu_ids = gpu_ids

    cfg.runner.max_epochs = total_epochs  # Epochs for the runner that runs the workflow
    cfg.total_epochs = total_epochs

    # Learning rate of optimizers. The LR is divided by 8 since the config file is originally for 8 GPUs
    cfg.optimizer.lr = lr

    ## Learning rate scheduler config used to register LrUpdater hook
    cfg.lr_config = dict(
        policy='CosineAnnealing',
        # The policy of scheduler, also support CosineAnnealing, Cyclic, etc. Refer to details of supported LrUpdater from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py#L9.
        by_epoch=False,
        warmup='linear',  # The warmup policy, also support `exp` and `constant`.
        warmup_iters=500,  # The number of iterations for warmup
        warmup_ratio=0.001,  # The ratio of the starting learning rate used for warmup
        min_lr=1e-07)

    # config to register logger hook
    cfg.log_config.interval = log_interval  # Interval to print the log

    # Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation.
    cfg.checkpoint_config.interval = checkpoint_interval  # The save interval is 1

    cfg.dataset_type = 'CocoDataset'  # Dataset type, this will be used to define the dataset
    cfg.classes = classes

    cfg.data.train.img_prefix = train_data_images
    cfg.data.train.classes = cfg.classes
    cfg.data.train.ann_file = train_ann_file
    cfg.data.train.type = 'CocoDataset'

    cfg.data.val.img_prefix = val_data_images
    cfg.data.val.classes = cfg.classes
    cfg.data.val.ann_file = val_ann_file
    cfg.data.val.type = 'CocoDataset'

    cfg.data.test.img_prefix = val_data_images
    cfg.data.test.classes = cfg.classes
    cfg.data.test.ann_file = val_ann_file
    cfg.data.test.type = 'CocoDataset'

    cfg.data.samples_per_gpu = batch_size  # Batch size of a single GPU used in testing
    cfg.data.workers_per_gpu = num_worker  # Worker to pre-fetch data for each single GPU

    # The config to build the evaluation hook, refer to https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/evaluation/eval_hooks.py#L7 for more details.
    cfg.evaluation.metric = 'bbox'  # Metrics used during evaluation

    # Set the epoch intervel to perform evaluation
    cfg.evaluation.interval = evaluation_interval

    cfg.evaluation.save_best = 'bbox_mAP'

    cfg.log_config.hooks = [dict(type='TextLoggerHook')]

    print("| config path:", cfg_path)
    # Save config file for inference later
    cfg.dump(cfg_path)
    # print(f'CONFIG:\n{cfg.pretty_text}')


if __name__ == '__main__':
    print("—" * 50)
    create_mm_config()
    print("—" * 50)

2.2 修改配置文件

一些沒有在生成配置文件中設置的,我們直接打開配置文件,進行修改,例如下邊的anchor_generator的一些參數。

mmdetection/work_dirs/dataset1/cascade_rcnn_r50_fpn_1x_job1/cascade_rcnn_r50_fpn_1x_job1.py

3. 開始訓練

在mmdetection根目錄下,也就是 mmdetection/ 目錄用命令行運行,可以等程序運行起來后,看顯存占用,然后調節batch_size。

單GPU訓練

模板

python tools/train.py ${配置文件} --gpu-ids ${gpu id}

樣例:我想利用第二張顯卡訓練,就將 –gpu-ids 設置為1

python tools/train.py work_dirs/dataset1/cascade_rcnn_r50_fpn_1x_job1/cascade_rcnn_r50_fpn_1x_job1.py --gpu-ids 1

多GPU訓練

模板

bash tools/dist_train.sh ${配置文件} ${gpu 數量}

樣例:我用兩張顯卡一起訓練

bash tools/dist_train.sh work_dirs/dataset1/cascade_rcnn_r50_fpn_1x_job1/cascade_rcnn_r50_fpn_1x_job1.py 2

4. 可視化模型的輸出

訓練完后可以看一下自己模型推理的結果,看一下效果。在我們的工作目錄下創建一個 visiual.py 文件。

mmdetection/work_dirs/dataset1/cascade_rcnn_r50_fpn_1x_job2/visiual.py

import glob
import os
import shutil

import cv2
import cv2.cv2
import numpy as np
from mmdet.apis import inference_detector, init_detector

root_path = os.getcwd()
job_num = '2'  # 根據job數,進行修改
model_name = f'cascade_rcnn_r50_fpn_1x_job{job_num}.py'  # 改為自己的模型名
test_images_path = os.path.join(root_path, '../../../data/dataset1/coco2017/train2017/')  # 改為自己想要推理的圖片
save_dir = 'results_visiual_job' + job_num  # 可視化結果保存的路徑

classes = ("point",)  # 改成自己的類別
image_id = (1,)  # 類別對應id
SCORE_THRESH = 0.1  # 置信度閾值,只顯示置信度>=閾值的bbox
DEVICE = 'cuda:0'  # 顯卡


def inference_res(model, images_filename):
    results = []
    for img_name in images_filename:
        img = test_images_path + img_name
        result = inference_detector(model, img)
        for i in range(len(result)):
            for j in result[i]:
                j = np.array(j).tolist()
                if j[-1] >= SCORE_THRESH:
                    # 這里注意原來是xmin, ymin, xmax, ymax.
                    # 根據需求進行保存,這里我就保存xmin, ymin, xmax, ymax.
                    pred = {'image_id': img_name,
                            'category_id': 1,  # 因為我只有一個類,推理出來的result只有置信度和bbox,
                            # 沒有類別信息,這里根據自己的需求改
                            'bbox': [j[0], j[1], j[2], j[3]],
                            'score': j[-1]}
                    results.append(pred)
    return results


def visiual(results):
    img_names = os.listdir(test_images_path)
    # lst = []
    for i in img_names:
        img = cv2.imread(os.path.join(test_images_path, i))
        for j in results:
            if j['image_id'] == i:
                if j['score'] >= SCORE_THRESH:
                    xmin = int(j['bbox'][0])
                    ymin = int(j['bbox'][1])
                    xmax = int(j['bbox'][2])
                    ymax = int(j['bbox'][3])
                    cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
                    # cv2.cv2.putText(img,str(round(j['score'], 3)),(xmin,ymin ),cv2.cv2.FONT_HERSHEY_COMPLEX, 0.7, (255, 255, 255), 3)

        cv2.imwrite(save_dir + '/' + i, img)  # 將結果保存到文件夾
    #     lst.append(img)
    # lst = cycle(lst)
    # key = 0
    # while key & 0xFF !=27:
    #     cv2.imshow("image",next(lst))
    #     key = cv2.waitKey(3000)
    # cv2.cv2.destroyAllWindows()     # esc結束可視化


if __name__ == '__main__':
    print("—" * 50)
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)
    os.makedirs(save_dir)
    best_epoch_filepath = glob.glob('best' + '*')[0]  # best_bbox_mAP_epoch_9.pth
    config = os.path.join(root_path, model_name)
    checkpoint = os.path.join(root_path, best_epoch_filepath)

    print('| config: ', config)
    print('| checkpoint: ', checkpoint)

    model = init_detector(config, checkpoint, device=DEVICE)
    images_filename = os.listdir(test_images_path)
    results = inference_res(model, images_filename)

    visiual(results)  # 可視化測試數據集
    print('| image save dir:', save_dir)
    print('| Visiual complete.')
    print("—" * 50)

5. gradio做到網頁

我相信面前的你肯定也是個願意折騰的小伙伴,那就讓我們把它做到網頁上,過程很簡單。

更多的配置,參考gradio官網文檔 【Gradio Getting Started】 【Gradio Docs】

5.1 安裝gradio

pip install gradio

5.2 編寫實例

我是在 mmdetection/ 目錄下新建了一個 gradio.py 文件。運行后就可以看到控制台輸出了一個網址,點進去,就可以上傳圖片,然后可以推理了。

mmdetection/gradio.py

import os

import gradio as gr
import numpy as np
from cv2 import cv2
from mmdet.apis import inference_detector, init_detector

root_path = os.getcwd()

classes = ("apple",)  # 改為自己的類別名
image_id = (1,)  # 類別名對應的id
SCORE_THRESH = 0.2  # 置信度閾值
DEVICE = 'cuda:0'  # 用那個顯卡推理
config_path = "./work_dirs/xuliandi/cascade_rcnn_r50_fpn_1x/cascade_rcnn_r50_fpn_1x.py"  # 配置文件,改為自己的
checkpoint_path = "./work_dirs/xuliandi/cascade_rcnn_r50_fpn_1x/best_bbox_mAP_epoch_9.pth"  # 權重文件,改為自己的

config = config_path
checkpoint = checkpoint_path
model = init_detector(config, checkpoint, device=DEVICE)


def inference_res(model, image_input):
    results = []
    result = inference_detector(model, image_input)
    for i in range(len(result)):
        for j in result[i]:
            j = np.array(j).tolist()
            if j[-1] >= SCORE_THRESH:
                pred = {'bbox': [j[0], j[1], j[2], j[3]],
                        'score': j[-1]}
                results.append(pred)
    return results


def detect_image(image_input):
    results = inference_res(model, image_input)

    for i in results:
        xmin = int(i['bbox'][0])
        ymin = int(i['bbox'][1])
        xmax = int(i['bbox'][2])
        ymax = int(i['bbox'][3])
        cv2.rectangle(image_input, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)  # 畫bbox
    return image_input


if __name__ == '__main__':
    gr.Interface(fn=detect_image, inputs="image", outputs="image", capture_session=True).launch()

⭐ 完結撒花


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM