mmcv閱讀筆記


mmcv

  • docs :文檔
  • example :一個訓練的例子
  • mmcv
    -- arraymisc: 兩個函數(量化 quantize 與反量化 dequantize)
    ./mmcv.utils.registry.py:登記註冊類 Registry,很重要的模塊
class Registry:
    """A registry that maps string names to classes.

    Classes are added with :meth:`register_module` (usable as a decorator or
    as a plain call) and looked up by name with :meth:`get`, which is what
    build functions use before instantiating the class.
    """

    # *** Simple parts are omitted in these notes. The methods below expect
    # ``self.name`` (the registry's display name) and ``self._module_dict``
    # (a dict mapping name -> class) to be set by that omitted code. ***

    def get(self, key):
        """Get the registry record.

        Used at build time: the class stored under ``key`` is looked up here
        and then instantiated by the caller.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class, or None if ``key`` is not
            registered.
        """
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, module_name=None, force=False):
        """Add ``module_class`` to ``self._module_dict``.

        Args:
            module_class (type): The class to register.
            module_name (str | None): Key to store it under; defaults to
                ``module_class.__name__``.
            force (bool): If True, overwrite an existing entry with the same
                key instead of raising.

        Raises:
            TypeError: If ``module_class`` is not a class.
            KeyError: If ``module_name`` is already registered and ``force``
                is False.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if not force and module_name in self._module_dict:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        self._module_dict[module_name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        """Old-style registration API, kept for backward compatibility.

        Registers ``cls`` and returns it unchanged, so it also works as a
        decorator. When called with ``cls=None`` it returns a decorator that
        remembers ``force``. Always emits a deprecation warning.
        """
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            1. Decorator without a name (the class name is used as the key):
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            2. Decorator with an explicit name:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            3. Plain call that registers an existing class:
            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            The registered class when ``module`` is given (or for the
            deprecated ``register_module(SomeClass)`` form); otherwise a
            decorator that registers its argument and returns it unchanged.

        Raises:
            TypeError: If ``force`` is not a bool, or ``name`` is neither None
                nor a str.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old API
        # (``register_module(SomeClass)``), while it may introduce unexpected
        # bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # Validate ``name`` before either registration path, so the plain-call
        # form (``module=...``) cannot store a non-str key either.
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')

        # Plain call (example 3): x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # Decorator use (examples 1 and 2): @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

-- cnn: 各種層的定義(define)、註冊(registry)與構建(build)
-- bricks: 基礎層(conv、relu、bn、padding 等)
./mmcv.cnn.bricks.registry.py

from mmcv.utils import Registry
# One registry ("container") per family of basic layers, so that each family
# is stored and looked up separately, by type name.
(CONV_LAYERS, NORM_LAYERS, ACTIVATION_LAYERS, PADDING_LAYERS,
 UPSAMPLE_LAYERS) = [
    Registry(layer_kind)
    for layer_kind in ('conv layer', 'norm layer', 'activation layer',
                       'padding layer', 'upsample layer')
]

./mmcv.cnn.bricks.activation.py
以其中一個為代表舉例說明:

import torch.nn as nn

from mmcv.utils import build_from_cfg
from .registry import ACTIVATION_LAYERS

# Register the stock torch activation classes. No explicit name is passed, so
# each class is stored under its own class name (e.g. 'ReLU').
for module in (nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
               nn.Sigmoid, nn.Tanh):
    ACTIVATION_LAYERS.register_module(module=module)

# Registration only records a class (a reference to it) in the registry's
# dict, much like an import; nothing is instantiated until a build function
# such as this one is called with a config.
def build_activation_layer(cfg):
    """Instantiate the activation layer described by ``cfg``.

    Args:
        cfg (dict): Layer config, which should contain:
            - type (str): Name of a class registered in
              ``ACTIVATION_LAYERS``, e.g. ``'ReLU'``.
            - layer args: Arguments needed to instantiate that layer.

    Returns:
        nn.Module: The newly created activation layer.
    """
    registry = ACTIVATION_LAYERS
    return build_from_cfg(cfg, registry)

-- utils: 一個計算 FLOPs 的工具,以及權重初始化(init)函數
-- 剩下的 alexnet、resnet、vgg 三個是實際的模型實現,沒有使用 mmcv 裡面的東西

-- fileio
handlers: pickle、yaml、json 等格式文件的讀寫處理類
file_client.py: 0.6 版本之後新增加的 FileClient 類,支援多種存儲後端的文件讀取,主要用於分布式訓練時大規模數據讀取的加速。
io.py: 整合各個 handlers,提供統一的文件讀寫接口(load / dump)。

-- image: 圖像相關的操作,讀取、處理、變換等

-- model_zoo: 預訓練模型的在線加載地址

-- ops: 需要高效實現的算子,如 NMS、RoIPool、RoIAlign、SyncBN 等;具體用法等使用 mmdetection 時再回來看。

-- parallel: 重新封裝了 torch 內部的並行計算,包括數據的 collate、scatter 以及分布式(distributed)並行等,熟悉 CUDA 的可以多了解。

-- runner: 包含 hook 和 runner 等訓練相關的類(重點)
-- hook: 把訓練過程中的各項操作(logger、checkpoint、lr 調整等)封裝為 Hook 類,並註冊到 registry 之中

./mmcv/runner/hooks/hook.py

from mmcv.utils import Registry
# Registry collecting every hook class (see ``@HOOKS.register_module()`` on
# the concrete hooks).
HOOKS = Registry('hook')

class Hook:
    """Base class that every hook derives from.

    Concrete hooks such as logger, checkpoint, iter and lr hooks subclass it.
    Shown here as an empty stub.
    """

各個 hook 的實現方式如下(以 checkpoint 為例),其它 hook 的寫法相同:
./mmcv/runner/hooks/checkpoint.py


import os

from ..dist_utils import master_only
from .hook import HOOKS, Hook

# Register the checkpoint hook in HOOKS; with no explicit name it is stored
# under the key 'CheckpointHook'.
@HOOKS.register_module()
class CheckpointHook(Hook):
    """Hook that saves checkpoints (body reduced to a stub in these notes)."""

    # master_only: per these notes, in multi-process training only the
    # rank-0 process runs this, so only one copy of the checkpoint is saved.
    @master_only
    def after_train_epoch(self, runner):
        """Triggered by the runner via ``call_hook('after_train_epoch')``."""

-- optimizer: 優化器模塊
./mmcv.runner.optimizer.builder.py

import copy
import inspect

import torch

from ...utils import Registry, build_from_cfg

# Registry of the standard torch.optim optimizer classes.
OPTIMIZERS = Registry('optimizer')
# Registry of mmcv optimizer *constructors*. They still create the optimizer
# from OPTIMIZERS internally, but can give different layers different lr and
# momentum settings.
OPTIMIZER_BUILDERS = Registry('optimizer builder')


def register_torch_optimizers():
    """Register every optimizer class exposed by ``torch.optim`` in OPTIMIZERS.

    The ``torch.optim.Optimizer`` base class itself also qualifies, because
    ``issubclass(C, C)`` is True.

    Returns:
        list[str]: The ``torch.optim`` attribute names that were registered,
        in ``dir()`` order.
    """
    registered_names = []
    for attr_name in dir(torch.optim):
        # Skip dunder attributes such as __name__ / __file__.
        if attr_name.startswith('__'):
            continue
        candidate = getattr(torch.optim, attr_name)
        is_optimizer_class = (inspect.isclass(candidate)
                              and issubclass(candidate, torch.optim.Optimizer))
        if not is_optimizer_class:
            continue
        OPTIMIZERS.register_module()(candidate)
        registered_names.append(attr_name)
    return registered_names


# Names (a list of str, not a dict) of the torch optimizers registered at
# import time.
TORCH_OPTIMIZERS = register_torch_optimizers()


def build_optimizer_constructor(cfg):
    """Instantiate the optimizer constructor described by ``cfg``.

    ``cfg['type']`` names a class registered in ``OPTIMIZER_BUILDERS`` (for
    example ``'DefaultOptimizerConstructor'``). The remaining keys are
    presumably passed to that class by ``build_from_cfg``.
    """
    registry = OPTIMIZER_BUILDERS
    return build_from_cfg(cfg, registry)

# Create an optimizer.
def build_optimizer(model, cfg):
    """Create an optimizer for ``model`` from an optimizer config.

    Args:
        model: The model to optimize. It is passed to the optimizer
            constructor once that constructor has been built.
        cfg (dict): Optimizer config. It is deep-copied, so the caller's dict
            is left untouched. Two optional keys are consumed here:
            ``'constructor'`` (name of a class registered in
            ``OPTIMIZER_BUILDERS``, default ``'DefaultOptimizerConstructor'``)
            and ``'paramwise_cfg'`` (default None). Every other key is
            forwarded as ``optimizer_cfg``.

    Returns:
        Whatever the constructor returns when called with ``model``
        (presumably a ``torch.optim.Optimizer``).
    """
    optimizer_cfg = copy.deepcopy(cfg)
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    constructor_cfg = dict(
        type=constructor_type,
        optimizer_cfg=optimizer_cfg,
        paramwise_cfg=paramwise_cfg)
    return build_optimizer_constructor(constructor_cfg)(model)

./mmcv.runner.optimizer.default_constructor.py

@OPTIMIZER_BUILDERS.register_module()
class DefaultOptimizerConstructor:
    """Default optimizer constructor.

    Wraps optimizer creation so that different layers can use different lr
    and momentum settings (body reduced to a stub in these notes).
    """

    def add_params(self, params, module, prefix=''):
        """Presumably collects ``module``'s parameters into ``params``.

        Body omitted in these notes.
        """

./mmcv.runner.base_runner.py

class BaseRunner(metaclass=ABCMeta):
    """Base class of runners (heavily abridged in these notes).

    A runner keeps its hooks in ``self._hooks`` and triggers them by name
    through :meth:`call_hook`. Methods whose body is only a docstring are
    placeholders in these notes.
    """

    def __init__(self, batch_processor):
        # batch_processor: a callable that computes the loss. Its call
        # signature is fixed as (model, data, train_mode). Per these notes,
        # the loss it returns is back-propagated later, in the optimizer
        # hook's after_train_iter (optimizer.py). With several losses, either
        # adapt batch_processor or change the backward step in
        # after_train_iter.
        pass

    def register_hook(self, hook, priority='NORMAL'):
        """Add ``hook`` to the runner's own list ``self._hooks``.

        ``priority`` presumably controls the hook's position in that list
        (body omitted in these notes).
        """

    def call_hook(self, fn_name):
        """Call the method named ``fn_name`` on every registered hook.

        Every hook in ``self._hooks`` is called, in list order, whether or not
        it does anything useful at this stage. The runner passes itself as the
        only argument. A hook that lacks ``fn_name`` raises AttributeError.
        """
        for hook in self._hooks:
            # The runner passes itself so hooks can read/modify its state.
            getattr(hook, fn_name)(self)

    def load_checkpoint(self, filename, map_location='cpu', strict=False):
        """Load (pre-trained) weights from ``filename`` (body omitted)."""

    def resume(self, *args, **kwargs):
        """Restore the state of a previous training run (body omitted).

        The argument list is elided in these notes. ``self`` is required
        because this is an instance method.
        """

    def register_training_hooks(self, *args, **kwargs):
        """Register hooks built from the configs passed in (body omitted).

        The real parameter list is elided in these notes.
        """

    def register_lr_hook(self, lr_config):
        """Register the learning-rate hook built from ``lr_config``.

        One of the concrete steps behind ``register_training_hooks`` (body
        omitted in these notes).
        """

    # run / train / val are covered in detail below.

./mmcv.runner.epoch_based_runner.py


# Epoch-based training. (mmcv also has a runner that trains for a given
# number of iterations instead of epochs.)
def train(self, data_loader, **kwargs):
        """Train for one epoch over ``data_loader``.

        Hooks are triggered before/after the epoch and before/after every
        iteration. Per these notes, the backward pass and parameter update
        happen inside the optimizer hook's ``after_train_iter``, not here.

        Args:
            data_loader: Iterable of batches. ``len(data_loader)`` is used to
                derive ``self._max_iters``.
            **kwargs: Forwarded to ``model.train_step`` or to
                ``batch_processor``.

        Raises:
            TypeError: If the per-batch step does not return a dict.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(data_loader)
        # Let every hook update its internal state before the epoch starts.
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for inner_iter, data_batch in enumerate(data_loader):
            self._inner_iter = inner_iter
            self.call_hook('before_train_iter')
            if self.batch_processor is not None:
                # A user-supplied loss function takes precedence ...
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=True, **kwargs)
            else:
                # ... otherwise the model computes its own loss.
                outputs = self.model.train_step(data_batch, self.optimizer,
                                                **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.train_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1
        
    # Mirrors train(), but in eval mode, without gradients, and without
    # advancing the iteration/epoch counters.
    def val(self, data_loader, **kwargs):
        """Run one validation pass over ``data_loader``.

        Args:
            data_loader: Iterable of batches.
            **kwargs: Forwarded to ``model.val_step`` or to
                ``batch_processor``.

        Raises:
            TypeError: If the per-batch step does not return a dict.
        """
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for inner_iter, data_batch in enumerate(data_loader):
            self._inner_iter = inner_iter
            self.call_hook('before_val_iter')
            with torch.no_grad():
                if self.batch_processor is not None:
                    outputs = self.batch_processor(
                        self.model, data_batch, train_mode=False, **kwargs)
                else:
                    outputs = self.model.val_step(data_batch, self.optimizer,
                                                  **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.val_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation; ``data_loaders[i]`` is used for
                ``workflow[i]``.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int): Total training epochs.
            **kwargs: Forwarded to every ``train``/``val`` call.

        Training stops once ``self.epoch`` reaches ``max_epochs``. The
        remaining phases of the current workflow round (e.g. a final 'val')
        still run, and the ``after_run`` hook is always called at the end.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        # Total iterations, derived from the first 'train' phase's loader.
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    # Resolves to self.train or self.val.
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        # `break`, not `return`: returning here would skip
                        # the `after_run` hook below.
                        break
                    # Calls self.train(...) or self.val(...).
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. It is copied before ``epoch``/``iter`` are added,
                so the caller's dict is not modified. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        # Work on a copy so the caller's dict is not mutated as a side effect.
        meta = {} if meta is None else dict(meta)
        # self.epoch + 1 is presumably the 1-based number of the epoch just
        # finished, since after_train_epoch hooks run before the epoch counter
        # is incremented.
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        # The bare name refers to the module-level save_checkpoint helper,
        # not this method.
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))



免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM