mmdetection2損失為nan


好久沒用mmdetection了,今天用visdrone數據集訓練一個cascade-rcnn-r101模型,損失總是會出現nan,先考率學習率是否太高的問題,學習率分別設為0.02*batchsize/16,0.000001,0,仍然出現損失為nan。所以懷疑是數據問題,可能存在無效的目標框(目標框的左下角坐標<=右上角坐標),具體判別代碼:

import xml.etree.ElementTree as ET
import os
xml_root = "./data"
new_xml_root = "./data"
image_root = "./data"
xml_name_list = sorted(os.listdir(xml_root))
def check_bbox():
    if not os.path.exists(new_xml_root):
        os.makedirs(new_xml_root)

    for xml_name in xml_name_list:
        xml_path = os.path.join(xml_root, xml_name)
        tree = ET.parse(xml_path)
        root = tree.getroot()
        for obj in root.findall("object"):
            bnd_box = obj.find("bndbox")
            bbox = [
                int(float(bnd_box.find("xmin").text)),
                int(float(bnd_box.find("ymin").text)),
                int(float(bnd_box.find("xmax").text)),
                int(float(bnd_box.find("ymax").text)),
            ]

            if bbox[0] >= bbox[2] or bbox[1] >= bbox[3]:
                print("bbox[0] >= bbox[2] or bbox[1] >= bbox[3]", bbox, xml_name)
check_bbox()

現記錄使用mmdetection2訓練visdrone的具體過程

  • 處理visdrone數據,將其txt標簽轉為VOC格式
import os
from PIL import Image

root_dir = "/mnt/A/pengyuan/data/Visd2019/trainval/"
annotations_dir = root_dir+"annotations/"
image_dir = root_dir + "images/"
xml_dir = root_dir+"Annotations/"  
class_name = ['ignored regions','pedestrian','people','bicycle','car','van','truck','tricycle','awning-tricycle','bus','motor','others']

for filename in os.listdir(annotations_dir):
    fin = open(annotations_dir+filename, 'r')
    image_name = filename.split('.')[0]
    img = Image.open(image_dir+image_name+".jpg")
    xml_name = xml_dir+image_name+'.xml'
    with open(xml_name, 'w') as fout:
        fout.write('<annotation>'+'\n')
        
        fout.write('\t'+'<folder>VOC2007</folder>'+'\n')
        fout.write('\t'+'<filename>'+image_name+'.jpg'+'</filename>'+'\n')
        
        fout.write('\t'+'<source>'+'\n')
        fout.write('\t\t'+'<database>'+'VisDrone2019 Database'+'</database>'+'\n')
        fout.write('\t\t'+'<annotation>'+'VisDrone2019'+'</annotation>'+'\n')
        fout.write('\t\t'+'<image>'+'flickr'+'</image>'+'\n')
        fout.write('\t\t'+'<flickrid>'+'Unspecified'+'</flickrid>'+'\n')
        fout.write('\t'+'</source>'+'\n')
        
        fout.write('\t'+'<owner>'+'\n')
        fout.write('\t\t'+'<flickrid>'+'Haipeng Zhang'+'</flickrid>'+'\n')
        fout.write('\t\t'+'<name>'+'Haipeng Zhang'+'</name>'+'\n')
        fout.write('\t'+'</owner>'+'\n')
        
        fout.write('\t'+'<size>'+'\n')
        fout.write('\t\t'+'<width>'+str(img.size[0])+'</width>'+'\n')
        fout.write('\t\t'+'<height>'+str(img.size[1])+'</height>'+'\n')
        fout.write('\t\t'+'<depth>'+'3'+'</depth>'+'\n')
        fout.write('\t'+'</size>'+'\n')
        
        fout.write('\t'+'<segmented>'+'0'+'</segmented>'+'\n')

        for line in fin.readlines():

            line = line.split(',')
            if int(line[5])==0 or int(line[5])==11:
                continue
            fout.write('\t'+'<object>'+'\n')
            print(line)
            print(image_name)
            fout.write('\t\t'+'<name>'+class_name[int(line[5])]+'</name>'+'\n')
            fout.write('\t\t'+'<pose>'+'Unspecified'+'</pose>'+'\n')
            fout.write('\t\t'+'<truncated>'+line[6]+'</truncated>'+'\n')
            fout.write('\t\t'+'<difficult>'+str(int(line[7]))+'</difficult>'+'\n')
            fout.write('\t\t'+'<bndbox>'+'\n')
            fout.write('\t\t\t'+'<xmin>'+line[0]+'</xmin>'+'\n')
            fout.write('\t\t\t'+'<ymin>'+line[1]+'</ymin>'+'\n')
            # pay attention to this point!(0-based)
            fout.write('\t\t\t'+'<xmax>'+str(int(line[0])+int(line[2])-1)+'</xmax>'+'\n')
            fout.write('\t\t\t'+'<ymax>'+str(int(line[1])+int(line[3])-1)+'</ymax>'+'\n')
            fout.write('\t\t'+'</bndbox>'+'\n')
            fout.write('\t'+'</object>'+'\n')
             
        fin.close()
        fout.write('</annotation>')
  • 設置mmdetection2訓練自己的數據集

    • 更改config.py
      這里我選擇的是./configs/cascade_rcnn/cascade_rcnn_r101_fpn_1x_coco.py的config,發現其是調用./configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py。
      由於本次使用VOC數據格式,故更改cascade_rcnn_r50_fpn_1x_coco.py的datasets為../-base-/datasets/voc0712.py

    • 更改datasets的config.py
      數據config在configs/-base-/datasets/voc0712.py,中,需要更改輸入尺寸,數據集路徑,batchsize,具體如下圖:

    • 更改學習率
      學習率在configs/-base-/schedules/schedule_1x.py中更改,單卡訓練學習率推薦值為0.02*batchsize/16,具體如下圖:

    • 更改類別數
      cascade_rcnn_r101_fpn_1x_coco.py的類別數在/configs/-base-/models/cascade_rcnn_r50_fpn.py中更改。將其中的num_classes改為需要的類別數,mmdetection2不需要再加1了,此處設為10

    • 更改類別名
      有兩個地方需要更改類別名,首先是mmdet/datatsets/voc.py中的class VOCDatasets,如下圖

      另外是計算mAP的地方需要更改類別,具體在/mmdet/core/evaluation/class_names.py

  • 訓練
    為了方便訓練,寫一個小腳本mytrain.sh來訓練

#!/bin/bash
CUDA_VISIBLE_DEVICES=6 python tools/train.py configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py 

然后運行sh mytrain.sh

  • 測試
    為了方便測試,寫一個小腳本mytest.sh來測試
#!/bin/bash
CUDA_VISIBLE_DEVICES=6 \
python tools/test.py configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_12.pth --out results.pkl --eval bbox --show \
python tools/voc_eval.py results.pkl ./configs/my_data.py
  • 有時候改完config后還是不能運行報錯:of MMDataParallel does not matches the length of CLASSES 20) in RepeatDataset。很有可能是需要重構代碼:

運行

python setup.py develop


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM