Detectron2源碼閱讀筆記-(二)Registry&build_*方法


Trainer解析

我們繼續Detectron2代碼閱讀筆記-(一)中的內容。

上圖畫出了detectron2文件夾中的三個子文件夾(tools,config,engine)之間的關系。那么剩下的文件夾又是如何起作用的呢?


def main(args):
    cfg = setup(args)

    if args.eval_only:
		...
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()

build_*方法

我們從trainer = Trainer(cfg)開始進一步了解。

Detectron2代碼閱讀筆記-(一)中已經提到過一連串的Trainer的繼承關系如下:
tools.train_net.Trainer->detectron2.engine.default.DefaultTrainer->detectron2.engine.train_loop.SimpleTrainer->detectron2.engine.train_loop.TrainerBase,而detectron2.engine.default.DefaultTrainer在其__init__(self, cfg)函數中定義了解析cfg。如下面代碼所示,cfg會作為參數倍若干個build_*方法解析,得到解析后的model,optimizer,data_loader等。

from detectron2.modeling import build_model
class DefaultTrainer(SimpleTrainer):
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)
		
		... 
		
        self.register_hooks(self.build_hooks())
		
	@classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

下面我們以DefaultTrainer.build_model為例來介紹注冊機制,該方法調用了detectron2/modeling/meta_arch/build_model.pybuild_model函數,其源代碼如下:

from detectron2.utils.registry import Registry

META_ARCH_REGISTRY = Registry("META_ARCH")
META_ARCH_REGISTRY.__doc__ = """
def build_model(cfg):
    """
    Built the whole model, defined by `cfg.MODEL.META_ARCHITECTURE`.
    """
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    return META_ARCH_REGISTRY.get(meta_arch)(cfg)
  • meta_arch = cfg.MODEL.META_ARCHITECTURE: 根據超參數獲得網絡結構的名字
  • return META_ARCH_REGISTRY.get(meta_arch)(cfg):META_ARCH_REGISTRY是一個Registry類(這個在后面會詳細介紹),可以將這一行代碼拆成如下幾個步驟:
model = META_ARCH_REGISTRY.get(meta_arch)
return model(cfg)

注冊機制Registry

那么Registry到底是什么呢?在分析源代碼之前我們先了解一下如何使用它,假如你想自己實現一個新的backbone網絡,那么你可以這樣做:

首先在detectron2中定義好如下(實際上已經定義了):

# detectron2/modeling/backbone/build.py
BACKBONE_REGISTRY = Registry('BACKBONE')

之后在你創建的新的文件下按如下方式創建你的backbone

# detectron2/modeling/backbone/your_backbone.py
from .build import BACKBONE_REGISTRY

# 方式1
@BACKBONE_REGISTRY.register()
class MyBackbone():
	...
		
# 方式2
class MyBackbone():
	...
BACKBONE_REGISTRY.register(MyBackbone)

Registry源代碼如下(有刪減):

class Registry(object):
    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(name, self._name)
        self._obj_map[name] = obj

    def register(self, obj=None):
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)

    def get(self, name):
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name))
        return ret
  • 首先是__init__部分:
    • self._name則是你要注冊的名字,例如對於完整的模型而言,name一般取META_ARCH。當然如果你需要自定義backbone網絡,你也可以定義一個Registry('BACKBONE')
    • self._obj_map:其實就是一個字典。以模型為例,key就是你的模型名字,而value就是對應的模型類。這樣你在傳參時只需要修改一下模型名字就能使用不同的模型了。具體實現方法就是后面這幾個函數。
  • register: 可以看到該方法定義了注冊的兩種方式,一種是當obj==None的時候,使用裝飾器的方式注冊,另外一種就是直接將obj作為參數調用_do_register進行注冊。
  • _do_register:真正注冊的函數,可以看到它首先會判斷name是否已經存在於self._obj_map了。什么意思呢?還是以backbone為例,我們定義了一個BACKBONE_REGISTRY = Registry('BACKBONE'),然后又定義了很多種backbone,而這些backbone都使用@BACKBONE_REGISTRY.register()的方式注冊到了BACKBONE_REGISTRY._obj_map中了,所以才取名為Registry,還是蠻形象的吼。
  • get: 這個其實就是根據key值對字典進行取值。

Detectron2 整體代碼架構

雖然Detectron2還有很多部分沒有介紹到,但是源代碼分析到這應該對整體架構有了一定的理解了,具體的一些細節會在后續的文章中進行分析。現對Detectron2 整體代碼架構總結一下:



微信公眾號:AutoML機器學習
MARSGGBO原創
如有意合作或學術討論歡迎私戳聯系~
郵箱:marsggbo@foxmail.com





2019-10-15 13:16:32




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM