# AssertionError: The `num_classes` (3) in Shared2FCBBoxHead of MMDataParallel does not matches the length of `CLASSES` 80) in CocoDataset


我看很多人都遇到了這個問題,有很多解決了的。我就把這篇博文再完善一下,讓大家對mmdetection使用得心應手。


mmdetection訓練自己的數據集時報錯 ⚠️ :

# AssertionError: The `num_classes` (3) in Shared2FCBBoxHead of MMDataParallel does not matches the length of `CLASSES` 80) in CocoDataset

你可能已經修改了以下兩個文件,但是還是報錯:

mmdetection-master\mmdet\core\evaluation\class_names.py

mmdetection-master\mmdet\datasets\coco.py

意思就是你指定的類別(3種)與CocoDataset的類別(80種)不匹配。

如果是報錯翻過來的話,也就是你指定的類別(80種)與CocoDataset的類別(3種)不匹配。一定是配置文件里設置錯了,去你的配置文件搜索num_classes,然后修改好。


廢話不多說,直接上方法。有以下幾種方法【經過我多次使用后,推薦第四種,方便的很】:

1️⃣ 是修改最少的,假設你有2個類,你就把上邊兩處地方,前2個類替換成你的類別。方法比較簡單,但是可能存在隱患。【不推薦】

2️⃣ 第二種方法就是修改完 class_names.py 和 voc.py 之后一定要重新編譯代碼(運行python setup.py install),再進行訓練。

我試了,有時候可以,有時候不行,可以嘗試一下。

參考:

新版 MMDetection V2.3.0訓練測試筆記 - it610.com

mmdetectionV2.x版本 訓練自己的VOC數據集_桃子醬momo的博客-CSDN博客

3️⃣ 第三種方法,我之前使用的方法,其實跟重新編譯一樣,重新編譯的原因就是因為環境里的源文件沒有修改,所以你才會報錯。mmdetection-master目錄下只是一些python文件,真正運行程序時,運行的還是環境里的源文件,因為我們直接去環境里修改源文件。

假設我的conda環境名為conda_env_name,因此去下面的目錄下,分別修改兩個文件:

\anaconda3\envs\conda_env_name\lib\python3.7\site-packages\mmdet\core\evaluation\class_names.py

\anaconda3\envs\conda_env_name\lib\python3.7\site-packages\mmdet\datasets\coco.py

在conda環境里把這兩個文件里的類別修改了,就可以了,這一招一定可以。

4️⃣ 第四種方法,更簡單,更方便,我現在使用的方法。直接在mmdetection配置文件中指定好所有要指定的東西,因為在mmdetection中配置文件的參數值優先級是最高的,所以不用管環境里有沒有修改,配置文件里修改了,就可以了。我寫了個腳本,把腳本放到mmdetection根目錄,根據自己要用的模型,把腳本中的變量都改成自己的。

我以cascade_mask_rcnn_r101為例:

# 在mmdetection的根目錄下運行,如果報錯:沒有那個參數,就把create_mm_config中那個參數賦值給注釋掉。生成配置文件后,直接修改配置文件就可以了。
import os
from mmcv import Config

#################################  下邊是要修改的內容   ####################################

root_path = os.getcwd()
model_name = 'cascade_mask_rcnn_r101'  # 改成自己要使用的模型名字
work_dir = os.path.join(root_path, "work_dirs", model_name)  # 訓練過程中,保存日志權重文件的路徑,。
baseline_cfg_path = os.path.join('configs', 'cascade_rcnn', 'cascade_mask_rcnn_r101_fpn_mstrain_3x_coco.py')
# 改成自己要使用的模型的配置文件路徑
save_cfg_path = os.path.join(work_dir, 'config.py')  # 生成的配置文件保存的路徑

train_data_images = os.path.join(root_path, 'data', 'train', 'images')  # 改成自己訓練集圖片的目錄。
val_data_images = os.path.join(root_path, 'data', 'train', 'images')  # 改成自己驗證集圖片的目錄。
test_data_images = os.path.join(root_path, 'data', 'val', 'images')  # 改成自己測試集圖片的目

train_ann_file = os.path.join(root_path, 'data', 'train', 'annotations', 'new_train.json')  # 修改為自己的數據集的訓練集json
val_ann_file = os.path.join(root_path, 'data', 'train', 'annotations', 'new_val.json')  # 修改為自己的數據集的驗證集json
test_ann_file = os.path.join(root_path, 'data', 'val', 'annotations', 'new_test.json')  # 修改為自己的數據集的驗證集json錄。

# 去找個網址里找你對應的模型的網址: https://github.com/open-mmlab/mmdetection/blob/master/README_zh-CN.md
load_from = os.path.join(work_dir, 'checkpoint.pth')  # 修改成自己的checkpoint.pth路徑

# File config
num_classes = 50  # 改成自己的類別數。
classes = ('1', '2', '3', '4', '5', '6', '7', '8', '9', '10',
           '11', '12', '13', '14', '15', '16', '17', '18', '19',
           '20', '21', '22', '23', '24', '25', '26', '27', '28',
           '29', '30', '31', '32', '33', '34', '35', '36', '37',
           '38', '39', '40', '41', '42', '43', '44', '45', '46',
           '47', '48', '49', '50')  # 改成自己的類別

###############  下邊一些參數包含不全,可以在生成的配置文件中再對其他參數進行修改    #####################

# Train config              # 根據自己的需求對下面進行配置
gpu_ids = range(0, 1)  # 改成自己要用的gpu
gpu_num = 1
total_epochs = 20  # 改成自己想訓練的總epoch數
batch_size = 2 ** 1  # 根據自己的顯存,改成合適數值,建議是2的倍數。
num_worker = 1  # 比batch_size小,就行
log_interval = 300  # 日志打印的間隔
checkpoint_interval = 7  # 權重文件保存的間隔
lr = 0.02 * batch_size * gpu_num / 16  # 學習率
ratios = [0.5, 1.0, 2.0]
strides = [4, 8, 16, 32, 64]

cfg = Config.fromfile(baseline_cfg_path)

if not os.path.exists(work_dir):
    os.makedirs(work_dir)

cfg.work_dir = work_dir
print("Save config dir:", work_dir)

# swin和mmdetection的訓練集配置不在一個地方,那個不報錯用哪個:
cfg.classes = classes
# mmdetection用這個:
cfg.data.train.img_prefix = train_data_images
cfg.data.train.classes = classes
cfg.data.train.ann_file = train_ann_file
# swin用這個,注釋上邊那個
# cfg.data.train.dataset.img_prefix = train_data_images
# cfg.data.train.dataset.classes = classes
# cfg.data.train.dataset.ann_file = train_ann_file

cfg.data.val.img_prefix = val_data_images
cfg.data.val.classes = classes
cfg.data.val.ann_file = val_ann_file

cfg.data.test.img_prefix = test_data_images
cfg.data.test.classes = classes
cfg.data.test.ann_file = test_ann_file

cfg.data.samples_per_gpu = batch_size
cfg.data.workers_per_gpu = num_worker
cfg.log_config.interval = log_interval

# 有些配置文件num_classes可能不在這個地方,生成之后去配置文件里搜索一下,看看都修改了沒
for head in cfg.model.roi_head.bbox_head:
    head.num_classes = num_classes
if "mask_head" in cfg.model.roi_head:
    cfg.model.roi_head.mask_head.num_classes = num_classes

cfg.load_from = load_from
cfg.runner.max_epochs = total_epochs
cfg.total_epochs = total_epochs
cfg.optimizer.lr = lr
cfg.checkpoint_config.interval = checkpoint_interval
cfg.model.rpn_head.anchor_generator.ratios = ratios
cfg.model.rpn_head.anchor_generator.strides = strides

cfg.dump(save_cfg_path)
print(save_cfg_path)
print("—" * 50)
print(f'CONFIG:\n{cfg.pretty_text}')
print("—" * 50)

生成配置文件后,路徑在 ./work_dirs/cascade_mask_rcnn_r101/config.py,在mmdetection根目錄下,又可以愉快的進行訓練了。

訓練命令(在mmdetection根目錄):4GPU訓練

./tools/dist_train.sh work_dirs/cascade_mask_rcnn_r101/config.py 4

⭐ 最終也功夫不負有心人,解決掉了這個bug,寫此博客,以幫助大家少走彎路。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM