模型訓練流程
-
從
tools/train.py
開始:- 一通讀取 cfg ,初步設置一些基本參數,log 參數;
- build 模型,build 數據集 (有多少個 workflow 就 build 多少個數據集,比如如果 train 的過程中還進行 val 則表示有 2 個 workflow) ;
- 最后調用
mmdet.apis.train_detector
,傳入剛才 build 好的 model,datasets,配置參數等。
-
進入
mmdet.apis.train_detector
:- 為每一個 workflow 對應的 dataset , build data_loader ( data_loader 繼承自 pytorch 自帶的 DataLoader 類,這里先簡單理解,其將 dataset 里面 data sample 包裝成 data batch ,作為生成器的形式,每次用 for 迭代 load batch ) ;
- 判斷是否是分布式訓練,分布式訓練則用
MMDistributedDataParallel
封裝 model,單 GPU 訓練則MMDataParallel
; - build optimizer;
- 重頭戲: runner ,runner 可以理解為操控整個訓練過程的核心。首先,先跳過中間那一堆對 runner hook 的設置,直接看到最后,調用了
runner.run()
,訓練從此處開始。
-
runner 是
EpochBasedRunner
類的實例,進入EpochBasedRunner
類的定義,可以看到最主要的是 run 方法:def run(self, data_loaders, workflow, max_epochs, **kwargs): #... while self.epoch < max_epochs: for i, flow in enumerate(workflow): mode, epochs = flow if isinstance(mode, str): # self.train() if not hasattr(self, mode): raise ValueError( f'runner has no method named "{mode}" to run an ' 'epoch') epoch_runner = getattr(self, mode) else: raise TypeError( 'mode in workflow must be a str, but got {}'.format( type(mode))) for _ in range(epochs): if mode == 'train' and self.epoch >= max_epochs: break epoch_runner(data_loaders[i], **kwargs)
workflow
變量的注釋:workflow (list[tuple]): A list of (phase, epochs) to specify the running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,最后 4 行是重點,根據每個 workflow 的 mode 和 epochs 調用 epochs 次相應的函數,比如:
for _ in range(epochs): if mode == 'train' and self.epoch >= max_epochs: break #when mode == 'train' # `epoch_runner(data_loaders[i], **kwargs)` == self.train(data_loaders[i], **kwargs)
一個 epoch 相當於遍歷一遍數據集的所有數據。
接下來看看 EpochBasedRunner.train() :
- 設置基本參數
- 在一些關鍵節點的前后調用了 hook :
before_train_epoch
,before_train_iter
,after_train_iter
,after_train_epoch
。執行反向傳播是在after_train_iter
處。 (先不糾結 hook 是個啥) - data_loader 為生成器,用 for 迭代取出 1 個 batch 的數據,進入逐個 iter 的訓練:
- 如果有為該 Runner 指定 batch processor,則調用。
- 否則,直接調用模型的 train_step,傳入訓練數據。
hook
hook 的作用是對一些中間結果做相應的操作,比如打印 log ,比如在 training 過程中的 evaluation 等等。
下面解析一下配置文件中出現的 TensorboardLoggerHook
先從 EpochBasedRunner 如何使用 hook 看起:
-
EpochBasedRunner.register_hook()
- 注冊 hook 到 runner,根據 hook cfg build 相應的 hook 實例,放到 runner 的 hook 隊列中。hook 隊列是一個優先級隊列,優先級可以在傳入 hook 的時候指定。
-
EpochBasedRunner.call_hook(fn_name)
- 使用 hook ,根據需要調用的函數名
fn_name
,調用每個 hook 里的同名函數,因為 runner 緩存着中間結果,需要將 runner 作為參數傳進去。
- 使用 hook ,根據需要調用的函數名
TensorboardLoggerHook
-
該類主要的作用是將每次 iter 或 epoch 完記錄訓練結果到 tensorboard (即寫到 summary 文件里)
-
TensorboardLoggerHook.after_train_iter(runner)
該函數做了什么?判斷是否達到 interval,比如在配置文件中指定了每 50 個 iter 才 log 訓練結果,如果達到 50 個 iter,則對 50 個 iter 的結果求平均,再調用自己的 log 函數。 50 個 iter 的結果存放在 runner.log_buffer 里。
-
TensorboardLoggerHook.log(runner)
將 runner.log_buffer 里的結果值,通過 summary_writer 寫到 summary 文件。
除了 Logger 這種形式的 hook 之外,還有其他一些功能也以 hook 的形式實現,比如 optimizer 對應的 OptimizerHook
,或者 training 過程中的 eval 也是通過 EvaluationHook
調用。