DETR訓練自己的數據集---使用方法


參考:1、[GitHub](https://github.com/DataXujing/detr_transformer)

           2、[Bilibili視頻](https://www.bilibili.com/video/BV1GC4y1h77h)


1、拷貝代碼
```
git clone https://github.com/facebookresearch/detr.git
```

2、創建新的虛擬環境(推薦)
```
conda create -n detr python=3.7
conda activate detr
```

3、安裝依賴庫
安裝PyTorch 1.5+ and torchvision 0.6+,安裝scipy
```
conda install -c pytorch pytorch torchvision
conda install cython scipy
```

安裝pycocotools (for evaluation on COCO):
```
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
```

4、數據准備(格式、路徑如下所示)
```
├─annotations # 標注的json文件,coco類型的標注
├─instances_train.json
├─instances_val.json
├─train # 訓練圖像的存放地址
├─xxx.jpg
├─val # 驗證圖像的存放地址
└─xxxx.jpg
```

5、下載預訓練模型,並修改類別數
修改配置文件change.py,將num_class改為“類別+1”
```
import torch

pretrained_weights = torch.load("./detr-r50-e632da11.pth")

num_class = 3 + 1
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1,256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)

torch.save(pretrained_weights,'detr_r50_%d.pth'%num_class)
```

```
python change.py # 在項目文件夾下生成detr_r50_{class_num}.pth
```

6、修改./models/detr.py中的build方法
```
def build(args):
num_classes = 3 if args.dataset_file != 'coco' else 3+1 #<--類別數 + 1
if args.dataset_file == "coco_panoptic": # 全景分割
num_classes = 3+1 # <-------------
device = torch.device(args.device)
```

7、按照自己需要修改main.py文件

8、訓練
```
python main.py --dataset_file "coco" --coco_path "/myData/coco" --epoch 500 --lr=1e-4 --batch_size=8 --num_workers=4 --output_dir="outputs" --resume="detr_r50_4.pth"

python main.py --dataset_file "coco" --coco_path "/data1/hzy/COCO2007" --epoch 50 --batch_size=4 --num_workers=4 --output_dir="outputs_1" --resume="detr_r50_55.pth"
```

9、測試
```
inference_img.py
```


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM