MMDet
For work I needed a quick pass over the source code, focusing mainly on the training and inference parts, which involve the Registry, Runner, and Hook components.
Core Libraries
The core libraries are MMDetection, MMSegmentation, MMDetection3D, and MMCV.
MMDetection3D: models and datasets for 3D object detection
MMDetection & MMSegmentation: models for standard 2D object detection and segmentation
MMCV: the foundational library of the MM series. It provides:
- Universal IO APIs
- Image/Video processing
- Image and annotation visualization
- Useful utilities (progress bar, timer, ...)
- PyTorch runner with hooking mechanism
- Various CNN architectures
- High-quality implementation of common CUDA ops
Core Components
The core components involved in training and inference with MMDet are MMCV's Registry and Runner. Hooks are also essential: the source code shows the developers relying on them heavily.
Registry
The Registry is the tool used to instantiate every object in MMDet, including models, datasets, optimizers, and so on.
Overall Flow
Examples:
1. Register a class: adds a class name -> class mapping
VOXEL_ENCODERS = Registry('voxel_encoder')
@VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module):
...
2. Instantiate the object: parse the config and instantiate the registered class
def build_voxel_encoder(cfg):
"""Build voxel encoder."""
return build(cfg, VOXEL_ENCODERS)
def build(cfg, registry, default_args=None):
...
return build_from_cfg(cfg, registry, default_args)
def build_from_cfg(cfg, registry, default_args=None):
    ...
    args = cfg.copy()
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    return obj_cls(**args)
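Putting the two steps together: a config dict whose type field names a registered class is all that is needed to build the object. A minimal usage sketch (no constructor arguments are passed here; real configs would carry the class's keyword arguments):
cfg = dict(type='HardSimpleVFE')
voxel_encoder = build_voxel_encoder(cfg)  # equivalent to HardSimpleVFE()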
Core Components
mmcv/utils/registry.py mainly uses self._module_dict to store the mapping from class names to classes.
class Registry:
"""A registry to map strings to classes.
Registered object could be built from registry.
Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> resnet = MODELS.build(dict(type='ResNet'))
Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced usage.
Args:
name (str): Registry name.
build_func(func, optional): Build function to construct instance from
Registry, func:`build_from_cfg` is used if neither ``parent`` or
``build_func`` is specified. If ``parent`` is specified and
``build_func`` is not given, ``build_func`` will be inherited
from ``parent``. Default: None.
parent (Registry, optional): Parent registry. The class registered in
children registry could be built from parent. Default: None.
scope (str, optional): The scope of registry. It is the key to search
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Default: None.
"""
def __init__(self, name, build_func=None, parent=None, scope=None):
self._name = name
self._module_dict = dict()
self._children = dict()
self._scope = self.infer_scope() if scope is None else scope
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
self._module_dict[name] = module_class
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
# while it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# raise the error ahead of time
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
raise TypeError(
'name must be either of None, an instance of str or a sequence'
f' of str, but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
Registering classes
VOXEL_ENCODERS = Registry('voxel_encoder')
MIDDLE_ENCODERS = Registry('middle_encoder')
FUSION_LAYERS = Registry('fusion_layer')
@VOXEL_ENCODERS.register_module()
class HardSimpleVFE(nn.Module):
    ...
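As the register_module docstring above shows, registration also works with an explicit name or as a plain function call. A small sketch with made-up classes:
BACKBONES = Registry('backbone')

@BACKBONES.register_module()  # registered under the class name 'ToyNet'
class ToyNet(nn.Module):
    ...

@BACKBONES.register_module(name='toy_v2')  # registered under a custom name
class ToyNetV2(nn.Module):
    ...

class ToyNetV3(nn.Module):
    ...

BACKBONES.register_module(module=ToyNetV3)  # plain function-call style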
build
def build(cfg, registry, default_args=None):
"""Build a module.
Args:
        cfg (dict, list[dict]): The config of modules; it is either a dict
or a list of configs.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
        # Note: this simply chains the component modules together into an nn.Sequential
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
try:
return obj_cls(**args)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{obj_cls.__name__}: {e}')
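To make the interaction between cfg and default_args concrete, here is a small self-contained sketch (ToyOptimizer and its arguments are hypothetical):
OPTIMIZERS = Registry('optimizer')

@OPTIMIZERS.register_module()
class ToyOptimizer:
    def __init__(self, lr, momentum=0.9):
        self.lr = lr
        self.momentum = momentum

# 'type' is popped from the copied cfg; default_args only fills in keys that
# the cfg does not already provide (via setdefault).
opt = build_from_cfg(dict(type='ToyOptimizer', lr=0.01),
                     OPTIMIZERS,
                     default_args=dict(momentum=0.95))
# equivalent to ToyOptimizer(lr=0.01, momentum=0.95)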
Runner
MMCV/runner
From the directory structure of runner we can see that it is mainly responsible for checkpointing, train, val, the optimizer, and hooks.
The training call chain across mmdet3d, mmdet and mmcv can be summarized as: mmdet3d/train.py -> mmdet/train_detector() -> mmcv/runner.run() -> mmcv/epoch_based_runner.py/train()
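For context, runner.run() dispatches to train()/val() according to the workflow defined in the config. The following is a simplified sketch of that dispatch (error handling, logging and length checks are omitted):
def run(self, data_loaders, workflow, **kwargs):
    # workflow is typically [('train', 1)] or [('train', 1), ('val', 1)]
    self.call_hook('before_run')
    while self.epoch < self._max_epochs:
        for i, (mode, epochs) in enumerate(workflow):
            epoch_runner = getattr(self, mode)  # resolves to self.train or self.val
            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                epoch_runner(data_loaders[i], **kwargs)
    self.call_hook('after_run')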
Implementation
mmcv/epoch_based_runner.py / train()
We can see that the train function is where the actual loop over the dataloader happens. It also calls a number of hooks; these were registered into the runner when it was instantiated, via the register_hook method implemented in BaseRunner, the parent class of EpochBasedRunner.
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
mmcv/epoch_based_runner.py / run_iter()
def run_iter(self, data_batch, train_mode, **kwargs):
if self.batch_processor is not None:
outputs = self.batch_processor(
self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
'and "model.val_step()" must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
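run_iter only requires that the model's train_step/val_step return a dict. In mmdet the detectors follow roughly the contract below (schematic, based on BaseDetector.train_step; the loss parsing is simplified):
def train_step(self, data, optimizer):
    losses = self(**data)  # forward pass returns a dict of per-branch losses
    loss, log_vars = self._parse_losses(losses)
    # 'loss' is what OptimizerHook.after_train_iter backpropagates;
    # 'log_vars' and 'num_samples' feed runner.log_buffer in run_iter above.
    return dict(loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))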
Hook
- Hook base class
mmcv/runner/hooks/hook.py
from mmcv.utils import Registry
HOOKS = Registry('hook')
class Hook:
def before_run(self, runner):
pass
def after_run(self, runner):
pass
def before_epoch(self, runner):
pass
def after_epoch(self, runner):
pass
def before_iter(self, runner):
pass
def after_iter(self, runner):
pass
def before_train_epoch(self, runner):
self.before_epoch(runner)
def before_val_epoch(self, runner):
self.before_epoch(runner)
def after_train_epoch(self, runner):
self.after_epoch(runner)
def after_val_epoch(self, runner):
self.after_epoch(runner)
def before_train_iter(self, runner):
self.before_iter(runner)
def before_val_iter(self, runner):
self.before_iter(runner)
def after_train_iter(self, runner):
self.after_iter(runner)
def after_val_iter(self, runner):
self.after_iter(runner)
def every_n_epochs(self, runner, n):
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n):
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n):
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner):
return runner.inner_iter + 1 == len(runner.data_loader)
def is_last_epoch(self, runner):
return runner.epoch + 1 == runner._max_epochs
def is_last_iter(self, runner):
return runner.iter + 1 == runner._max_iters
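Writing a new hook is just a matter of subclassing Hook, overriding the callbacks you need, and registering the class in the HOOKS registry; it can then be enabled from the config via custom_hooks. A minimal sketch (CheckLossHook is a made-up name):
import torch

@HOOKS.register_module()
class CheckLossHook(Hook):
    """Hypothetical hook that warns when the training loss becomes NaN."""

    def after_train_iter(self, runner):
        loss = runner.outputs['loss']
        if torch.isnan(loss):
            runner.logger.warning(
                f'NaN loss at epoch {runner.epoch}, inner iter {runner.inner_iter}')

# enabled from a config file:
# custom_hooks = [dict(type='CheckLossHook', priority='LOW')]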
All hooks
__all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
'DistEvalHook', 'ProfilerHook'
]
- Embedding hooks into the runner
A priority-ordered list stores the instantiated hook objects so that hooks are invoked in priority order. The priority levels are defined as follows:
"""Hook priority levels. +------------+------------+ | Level | Value | +============+============+ | HIGHEST | 0 | +------------+------------+ | VERY_HIGH | 10 | +------------+------------+ | HIGH | 30 | +------------+------------+ | NORMAL | 50 | +------------+------------+ | LOW | 70 | +------------+------------+ | VERY_LOW | 90 | +------------+------------+ | LOWEST | 100 | +------------+------------+ """
There are two ways to register hooks into the runner:
- register_hook
- register_hook_from_cfg
Both methods are implemented in the runner base class mmcv/runner/base_runner.py. In register_hook the hook list is traversed in reverse order; as soon as an existing hook with equal or higher priority (i.e. a priority value less than or equal to the new hook's) is found, the new hook is inserted right after it; if no such hook exists, the new hook is placed at the front.
def register_hook(self, hook, priority='NORMAL'):
    """Register a hook into the hook list.

    The hook will be inserted into a priority queue, with the specified
    priority (See :class:`Priority` for details of priorities).
    For hooks with the same priority, they will be triggered in the same
    order as they are registered.

    Args:
        hook (:obj:`Hook`): The hook to be registered.
        priority (int or str or :obj:`Priority`): Hook priority.
            Lower value means higher priority.
    """
    assert isinstance(hook, Hook)
    if hasattr(hook, 'priority'):
        raise ValueError('"priority" is a reserved attribute for hooks')
    priority = get_priority(priority)
    hook.priority = priority
    # insert the hook to a sorted list
    inserted = False
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        self._hooks.insert(0, hook)

def register_hook_from_cfg(self, hook_cfg):
    """Register a hook from its cfg.

    Args:
        hook_cfg (dict): Hook config. It should have at least keys 'type'
            and 'priority' indicating its type and priority.

    Notes:
        The specific hook class to register should not use 'type' and
        'priority' arguments during initialization.
    """
    hook_cfg = hook_cfg.copy()
    priority = hook_cfg.pop('priority', 'NORMAL')
    hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
    self.register_hook(hook, priority=priority)
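A toy sketch of the resulting order (hook classes and priorities here are illustrative):
class HookA(Hook): pass
class HookB(Hook): pass
class HookC(Hook): pass

runner.register_hook(HookA(), priority='LOW')        # value 70
runner.register_hook(HookB(), priority='VERY_HIGH')  # value 10
runner.register_hook(HookC(), priority='NORMAL')     # value 50
# runner._hooks is kept sorted by ascending priority value:
# [HookB (10), HookC (50), HookA (70)]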
- Calling hooks in the runner
call_hook iterates over the priority-ordered list, so priorities are respected. From the implementation we can see that every call to call_hook invokes every hook in the list, each executing its own implementation of the method named fn_name.
def call_hook(self, fn_name):
    """Call all hooks.

    Args:
        fn_name (str): The function name in each hook to be called, such as
            "before_train_epoch".
    """
    for hook in self._hooks:
        getattr(hook, fn_name)(self)
- Registering hooks before training
After the runner object is instantiated, the hooks it uses are registered:
# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
                               cfg.checkpoint_config, cfg.log_config,
                               cfg.get('momentum_config', None))
if distributed:
    if isinstance(runner, EpochBasedRunner):
        runner.register_hook(DistSamplerSeedHook())

# register eval hooks
if validate:
    # Support batch_size > 1 in validation
    val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
    if val_samples_per_gpu > 1:
        # Replace 'ImageToTensor' to 'DefaultFormatBundle'
        cfg.data.val.pipeline = replace_ImageToTensor(cfg.data.val.pipeline)
    val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
    val_dataloader = build_dataloader(
        val_dataset,
        samples_per_gpu=val_samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)
    eval_cfg = cfg.get('evaluation', {})
    eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
    eval_hook = DistEvalHook if distributed else EvalHook
    runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

# user-defined hooks
if cfg.get('custom_hooks', None):
    custom_hooks = cfg.custom_hooks
    assert isinstance(custom_hooks, list), \
        f'custom_hooks expect list type, but got {type(custom_hooks)}'
    for hook_cfg in cfg.custom_hooks:
        assert isinstance(hook_cfg, dict), \
            'Each item in custom_hooks expects dict type, but got ' \
            f'{type(hook_cfg)}'
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = build_from_cfg(hook_cfg, HOOKS)
        runner.register_hook(hook, priority=priority)
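The user-defined hooks in the last block come from the custom_hooks entry of the config; a hedged example of what such an entry might look like (the hook and priority chosen here are arbitrary):
custom_hooks = [
    dict(type='EMAHook', priority='NORMAL'),
]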
mmcv/runner/base_runner.py
def register_training_hooks(self,
                            lr_config,
                            optimizer_config=None,
                            checkpoint_config=None,
                            log_config=None,
                            momentum_config=None,
                            timer_config=dict(type='IterTimerHook'),
                            custom_hooks_config=None):
    """Register default and custom hooks for training.

    Default and custom hooks include:

      Hooks                  Priority
      - LrUpdaterHook        10
      - MomentumUpdaterHook  30
      - OptimizerStepperHook 50
      - CheckpointSaverHook  70
      - IterTimerHook        80
      - LoggerHook(s)        90
      - CustomHook(s)        50 (default)
    """
    self.register_lr_hook(lr_config)
    self.register_momentum_hook(momentum_config)
    self.register_optimizer_hook(optimizer_config)
    self.register_checkpoint_hook(checkpoint_config)
    self.register_timer_hook(timer_config)
    self.register_logger_hooks(log_config)
    self.register_custom_hooks(custom_hooks_config)
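For reference, the config entries that feed these register_* calls typically look like the following (a hedged example with commonly seen values, not a prescription):
lr_config = dict(policy='step', step=[8, 11])            # -> LrUpdaterHook
optimizer_config = dict(grad_clip=None)                   # -> OptimizerHook
checkpoint_config = dict(interval=1)                      # -> CheckpointHook
log_config = dict(interval=50,
                  hooks=[dict(type='TextLoggerHook')])    # -> LoggerHook(s)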
- Calling hooks during training and validation
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1
    self.call_hook('after_train_epoch')
    self._epoch += 1

@torch.no_grad()
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')
    self.call_hook('after_val_epoch')
Examples:
During training, the self.call_hook('after_train_iter') call should be where the backward pass and the gradient update happen; in other words, the optimizer hook should implement an after_train_iter method that performs .backward() and optimizer.step():
@HOOKS.register_module()
class OptimizerHook(Hook):

    def __init__(self, grad_clip=None):
        self.grad_clip = grad_clip

    def clip_grads(self, params):
        params = list(
            filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            return clip_grad.clip_grad_norm_(params, **self.grad_clip)

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip is not None:
            grad_norm = self.clip_grads(runner.model.parameters())
            if grad_norm is not None:
                # Add grad norm to the logger
                runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                         runner.outputs['num_samples'])
        runner.optimizer.step()
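Since grad_clip is unpacked straight into clip_grad_norm_, gradient clipping is switched on purely through optimizer_config in the config file; a hedged example with commonly used values:
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# clip_grads() then calls clip_grad.clip_grad_norm_(params, max_norm=35, norm_type=2)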