之前的博客分享過YOLOv3/v4算法訓練自己數據的方式,也介紹過模型加速的一般流程,以及本篇模型加速使用的通道剪枝的方式。沒做筆記的同學可以回看之后再來閱讀本篇噢!
此處再來回顧一下模型加速的一般流程。
因此,此處省略基礎模型訓練的環節。
本項目實踐全仰仗github:https://github.com/tanluren/yolov3-channel-and-layer-pruning,github中ReadMe寫的比較清楚,各位同學可以參考ReadMe,也可以參考本篇!
稀疏訓練
稀疏訓練建立在基礎模型訓練的基礎上,也就是YOLOv3/v4原型的訓練。(建議使用YOLOv3, YOLOv4在稀疏訓練中可能會存在一些問題。可以直接剪枝!)
去除.weights中的epoch信息
將訓練后的模型放到github項目對應的文件夾中,由於Darknet訓練得到的.weights文件是帶有epoch記錄的,因此,需要先通過轉換,將.weights中的epoch信息去掉(這里去掉的含義相當於置零),便於進行稀疏訓練。
如果基於.weights(帶有epoch的)的情況下,需要在訓練時將迭代次數設為.weights帶有的epoch + 稀疏訓練的代數。如果低於.weights帶有的epoch,訓練時模型會直接保存權重並停止。
本篇使用將epoch去掉的方式,將epoch去掉的方法在models.py的convert方法,該方法需要傳入兩個參數,其一是cfg文件,其二是權重文件。
python -c "from models import *; convert('cfg/yolov3-voc.cfg', 'yolov3-voc.weights')"
執行上述代碼,會生成convert.pt權重文件(pytorch版的)。
之后,便可以使用該權重文件進入稀疏訓練的環節。
稀疏訓練
github上給出了多種剪枝的方式,也給出了相應的調用方法。本篇使用--prune 1適用其他剪枝策略的方式。執行下述代碼,將對應的配置文件名稱,訓練文件名稱,權重文件名稱替換成自己的;batch-size和device的設置根據自己電腦的配置進行修改,其余參數可根據個人喜好進行修改。如果是yolov4,改成yolov4對應的文件即可。
python train.py --cfg cfg/yolov3-voc.cfg --data data/voc.data --weights yolov3-voc_uav.pt --batch-size 20 --epochs 480 -sr --s 0.005 --prune 1 --device 0
代碼中每隔10代保存一次權重文件,訓練代數設置的比較大的話會十分占硬盤空間的。因此,本篇將保存代數設置為80。
訓練完成后,會在weights文件夾中生成對應的權重文件,注意區分保存best.pt和last.pt,極容易在后續訓練過程中被替換掉。
同時,可以打開runs文件夾,對相應的events文件使用TensorBoard觀察訓練中的情況:
tensorboard --logdir ./
系數訓練的情況對剪枝會造成比較大的影響,因此,本篇在使用時設置了480 epoch。這種影響體現在相同的剪枝力度會得到不同的參數量。
剪枝
該過程是建立在稀疏訓練之后,本篇選用slim_prune.py的剪枝方式,使用的是通道剪枝,指令如下。
python slim_prune.py --cfg cfg/yolov3-voc.cfg --data data/voc.data --weights weights/last.pt --global_percent 0.85 --layer_keep 0.01
執行上述指令,會在cfg文件中生成相應的cfg文件,weights文件夾中生成相應的.weights文件。本篇記錄了yolov3剪枝力度從0.45-0.85的情況,如下圖所示,可以看出不同剪枝力度剪枝后,mAP,參數量,推理時間的對應關系。
這里有個比較有意思的事情,就是使用稀疏訓練生成的best.pt / last.pt分別進行不同力度的剪枝,會出現不一樣的效果:
對best.pt進行剪枝
對best.pt進行剪枝力度從0.50-0.85進行嘗試,結果如下圖所示,可以看出,使用best.pt進行剪枝,mAP隨着剪枝力度的增大而增大,呈現一定規律性。但剪枝后mAP始終低於原模型。
對last.pt進行剪枝
對last.pt進行剪枝力度從0.45-0.70進行嘗試。結果如下圖所示,可以看出,使用last.pt進行剪枝,剪枝力度仿佛對mAP沒有什么影響,掉點並不明顯。並且存在剪枝后mAP高於原模型的情況。
微調
微調意義在於對剪枝之后的模型恢復精度。也可以在微調的模型上加大數據集,使得剪枝后的模型泛化能力可以具有與原模型比肩的能力。
python train.py --cfg cfg/prune_0.85_keep_0.01_yolov3-voc.cfg --data data/voc.data --weights weights/prune_0.85_keep_0.01_last.weights --epochs 400 --batch-size 20 --device 0
微調后,保存的模型又是.pt文件,會生成和稀疏訓練一樣的best.pt和last.pt,很有可能會替代之前的文件,如果之前的文件放置在了weights文件目錄下。
.pt文件轉.weights文件
為了便於將生成的權重文件和配置文件放回Darknet中進行測試,需要將微調生成的.pt權重文件轉換為.weights文件。
由於剪枝過程是使用.pt文件為輸入,生成的是.weights文件,個人通過對slim_prune.py代碼的閱讀,從中摳出了.pt文件轉.weights文件的代碼,不過仍需要基於該github相關的文件。注意,不是單純執行下述代碼就可以實現.pt文件向.weights文件,需要基於github項目。
1 # -*- coding: utf-8 -*- 2 # @Time : 2020/7/7 上午9:22 3 # @Author : monologuesmw 4 # @Email : monologuesmw@163.com 5 # @File : pt2weights.py 6 # @Software: PyCharm 7 8 import torch 9 10 from models import * 11 from utils.prune_utils import * # 基於github項目中的兩個文件 12 13 import argparse 14 import numpy as np 15 16 if __name__ == "__main__": 17 parser = argparse.ArgumentParser() 18 parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') 19 # parser.add_argument('--data', type=str, default='data/coco.data', help='*.data file path') 20 parser.add_argument('--weights', type=str, default='weights/last.pt', help='sparse model weights') 21 # parser.add_argument('--global_percent', type=float, default=0.8, help='global channel prune percent') 22 # parser.add_argument('--layer_keep', type=float, default=0.01, help='channel keep percent per layer') 23 parser.add_argument('--img_size', type=int, default=416, help='inference size (pixels)') 24 opt = parser.parse_args() 25 26 img_size = opt.img_size 27 cfg_path = opt.cfg 28 pt_path = opt.weights 29 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 model = Darknet(opt.cfg, (img_size, img_size)).to(device) 31 assert pt_path 32 model.load_state_dict(torch.load(opt.weights, map_location=device)['model']) 33 34 compact_model_name = "my_test.weights" 35 save_weights(model, path=compact_model_name) 36 print("success save weights")
執行下述指令便可以得到相應的my_test.weights文件。
python pt2weights.py --cfg cfg/prune_0.85_keep_0.01_yolov3-voc.cfg --weights weights/best.pt
(訓練、測試數據部分來源於Kaggle)