mmdetection模型構建及Registry注冊器機制


        好久沒有做目標檢測了,最近突然又接到了檢測任務,跟同事討論時,發現自己竟然忘了很多細節,

於是想趁訓練模型的間隙,重新梳理下目標檢測。我選擇了mmdetection來學習,除了目標檢測本身,

這個框架中很多python的使用技巧和框架的設計模式也是值得學習。最近一年基本都在使用python,

希望能將這些技巧應用在以后的工作之中。mmdetection封裝的很好方便使用,比如我想訓練的

話只需如下的一條指令。在train.py中,通過build_detector來構建模型(參數來自 faster_rcnn_r50_fpn_1x_voc0712.py),

python tools/train.py  configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py

build_detector的定義如下,最后通過build_from_cfg來構建模型,這里看到了讓人困惑的Registry.

from mmdet.cv_core.utils import Registry, build_from_cfg from torch import nn BACKBONES = Registry('backbone') NECKS = Registry('neck') ROI_EXTRACTORS = Registry('roi_extractor') SHARED_HEADS = Registry('shared_head') HEADS = Registry('head') LOSSES = Registry('loss') DETECTORS = Registry('detector') def build(cfg, registry, default_args=None): """Build a module. Args: cfg (dict, list[dict]): The config of modules, is is either a dict or a list of configs. registry (:obj:`Registry`): A registry the module belongs to. default_args (dict, optional): Default arguments to build the module. Defaults to None. Returns: nn.Module: A built nn module. """
    if isinstance(cfg, list): modules = [ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg ] return nn.Sequential(*modules) else: return build_from_cfg(cfg, registry, default_args) def build_detector(cfg, train_cfg=None, test_cfg=None): """Build detector."""
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

       

一、Registry是干什么的

        Registry完成了從字符串到類的映射,這樣模型信息、訓練時的參數信息,只需要寫入到一個配置文件里,然后使用注冊器來實例化即可。

二、如何實現

        通過裝飾器來實現。在mmcv/mmcv/registry.py中,我們看到了Registry類。其中完成字符串到類的映射,實際上就是下面的成員函數來實現的,核心代碼就一句,將要注冊的類添加到字典里,key為類的名字(字符串)。下面通過一個小例子,

 def _register_module(self, module_class, module_name=None, force=False): if not inspect.isclass(module_class): raise TypeError('module must be a class, ' f'but got {type(module_class)}') if module_name is None: module_name = module_class.__name__
        if not force and module_name in self._module_dict: raise KeyError(f'{module_name} is already registered ' f'in {self.name}') self._module_dict[module_name] = module_class

 來看看它的構建過程。在導入下面這個文件時,首先創建FRUIT實例,接着通過裝飾器(這里是用成員函數裝飾類)來注冊Apple類,調用register_module,然后調用_register(注意:參數cls即為類Apple),最后調用_register_module完成Apple的添加。完成后,FRUIT就有了個字典成員:['Apple']=APPle。在build_from_cfg中,傳入模型參數,即可通過FRUIT構建Apple的實例化對象。

class Registry(): def __init__(self, name): self._name = name self._module_dict = dict() def _register_module(self, module_class, module_name, force): self._module_dict[module_name] = module_class def register_module(self, name=None, force=False, module=None): print('register module ...') def _register(cls): print('cls ', cls) self._register_module( module_class=cls, module_name=name, force=force) return cls return _register FRUIT = Registry('fruit') @FRUIT.register_module() class Apple(): def __init__(self, name): self.name = name

def build_from_cfg(cfg, registry, default_args=None):
   

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
    
    return obj_cls(**args)

三、Registry在mmdetection中是如何構建模型的

          我們來看一下構建模型的流程:

        1、在train.py中通過build_detector構建模型,其中cfg.model, cfg.train_cfg如下,包括模型信息和訓練信息。

model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) 

         

        2、最關鍵的部分來了。首先通過build_detector構建模型, 其中傳入的DETECTORS是Registry的實例,在該實例中,包含了所有已經實現的檢測器,如圖。那么它是在哪里實現添加這些檢測的類的呢?

def build_detector(cfg, train_cfg=None, test_cfg=None): """Build detector."""
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

 

           看了前面那個小例子我們就能猜到,一定是在這些檢測類上,用Registry對其進行了注冊,看看faster rcnn的實現,證明了我們的猜想。這樣只要

在定義這些類時,對其進行注冊,那么就會自動加入到DETECTORS這個實例的成員字典里,非常的巧妙。當我們想實例化某個檢測網絡時,傳入其字符名稱

即可。

       既然都看到這里了,就進一步看看網絡時如何繼續構建的吧。mmdetection將網絡分成了幾個部分,backbone,head,neck等。在TwoStageDetector(

faster rcnn的基類)中,可以看到分別構建了這幾個部分。head, neck, loss等,同樣是通過Registry來注冊實現的。最后就是將這幾個部分組合起來即可。

@DETECTORS.register_module() class TwoStageDetector(BaseDetector): """Base class for two-stage detectors. Two-stage detectors typically consisting of a region proposal network and a task-specific regression head. """

    def __init__(self, backbone, neck=None, rpn_head=None, roi_head=None, train_cfg=None, test_cfg=None, pretrained=None): super(TwoStageDetector, self).__init__() self.backbone = build_backbone(backbone) if neck is not None: self.neck = build_neck(neck) if rpn_head is not None: rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None rpn_head_ = rpn_head.copy() rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn) self.rpn_head = build_head(rpn_head_) if roi_head is not None: # update train and test cfg here for now
            # TODO: refactor assigner & sampler
            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None roi_head.update(train_cfg=rcnn_train_cfg) roi_head.update(test_cfg=test_cfg.rcnn) self.roi_head = build_head(roi_head) self.train_cfg = train_cfg self.test_cfg = test_cfg self.init_weights(pretrained=pretrained)

 

四、Registry的應用

          在我最近的一個數據處理的項目中,有三類數據,sample, measure 和image。如果我想得到某個數據類型的實例,我是通過if來

判斷的。那如果數據類別很多呢?就像檢測器這樣有幾十種,再用if就顯得很蠢了。借用Registry機制,可以輕松解決這個問題。

 

 

 

 

        

 

              

      

       

      


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM