Trainer解析
我們繼續Detectron2代碼閱讀筆記-(一)中的內容。
上圖畫出了detectron2
文件夾中的三個子文件夾(tools,config,engine)之間的關系。那么剩下的文件夾又是如何起作用的呢?
def main(args):
cfg = setup(args)
if args.eval_only:
...
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
if cfg.TEST.AUG.ENABLED:
trainer.register_hooks(
[hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
)
return trainer.train()
build_*方法
我們從trainer = Trainer(cfg)
開始進一步了解。
Detectron2代碼閱讀筆記-(一)中已經提到過一連串的Trainer的繼承關系如下:
tools.train_net.Trainer->detectron2.engine.default.DefaultTrainer->detectron2.engine.train_loop.SimpleTrainer->detectron2.engine.train_loop.TrainerBase
,而detectron2.engine.default.DefaultTrainer
在其__init__(self, cfg)
函數中定義了解析cfg。如下面代碼所示,cfg會作為參數倍若干個build_*
方法解析,得到解析后的model,optimizer,data_loader等。
from detectron2.modeling import build_model
class DefaultTrainer(SimpleTrainer):
def __init__(self, cfg):
"""
Args:
cfg (CfgNode):
"""
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg)
...
self.register_hooks(self.build_hooks())
@classmethod
def build_model(cls, cfg):
"""
Returns:
torch.nn.Module:
"""
model = build_model(cfg)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
return model
下面我們以DefaultTrainer.build_model
為例來介紹注冊機制,該方法調用了detectron2/modeling/meta_arch/build_model.py
的build_model
函數,其源代碼如下:
from detectron2.utils.registry import Registry
META_ARCH_REGISTRY = Registry("META_ARCH")
META_ARCH_REGISTRY.__doc__ = """
def build_model(cfg):
"""
Built the whole model, defined by `cfg.MODEL.META_ARCHITECTURE`.
"""
meta_arch = cfg.MODEL.META_ARCHITECTURE
return META_ARCH_REGISTRY.get(meta_arch)(cfg)
- meta_arch = cfg.MODEL.META_ARCHITECTURE: 根據超參數獲得網絡結構的名字
- return META_ARCH_REGISTRY.get(meta_arch)(cfg):META_ARCH_REGISTRY是一個
Registry
類(這個在后面會詳細介紹),可以將這一行代碼拆成如下幾個步驟:
model = META_ARCH_REGISTRY.get(meta_arch)
return model(cfg)
注冊機制Registry
那么Registry
到底是什么呢?在分析源代碼之前我們先了解一下如何使用它,假如你想自己實現一個新的backbone網絡,那么你可以這樣做:
首先在detectron2中定義好如下(實際上已經定義了):
# detectron2/modeling/backbone/build.py
BACKBONE_REGISTRY = Registry('BACKBONE')
之后在你創建的新的文件下按如下方式創建你的backbone
# detectron2/modeling/backbone/your_backbone.py
from .build import BACKBONE_REGISTRY
# 方式1
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
# 方式2
class MyBackbone():
...
BACKBONE_REGISTRY.register(MyBackbone)
Registry
源代碼如下(有刪減):
class Registry(object):
def __init__(self, name):
self._name = name
self._obj_map = {}
def _do_register(self, name, obj):
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(name, self._name)
self._obj_map[name] = obj
def register(self, obj=None):
if obj is None:
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__
self._do_register(name, obj)
def get(self, name):
ret = self._obj_map.get(name)
if ret is None:
raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name))
return ret
- 首先是
__init__
部分:self._name
則是你要注冊的名字,例如對於完整的模型而言,name一般取META_ARCH
。當然如果你需要自定義backbone網絡,你也可以定義一個Registry('BACKBONE')
self._obj_map
:其實就是一個字典。以模型為例,key就是你的模型名字,而value就是對應的模型類。這樣你在傳參時只需要修改一下模型名字就能使用不同的模型了。具體實現方法就是后面這幾個函數。
register
: 可以看到該方法定義了注冊的兩種方式,一種是當obj==None
的時候,使用裝飾器的方式注冊,另外一種就是直接將obj作為參數調用_do_register
進行注冊。_do_register
:真正注冊的函數,可以看到它首先會判斷name是否已經存在於self._obj_map
了。什么意思呢?還是以backbone為例,我們定義了一個BACKBONE_REGISTRY = Registry('BACKBONE')
,然后又定義了很多種backbone,而這些backbone都使用@BACKBONE_REGISTRY.register()
的方式注冊到了BACKBONE_REGISTRY._obj_map
中了,所以才取名為Registry
,還是蠻形象的吼。get
: 這個其實就是根據key值對字典進行取值。
Detectron2 整體代碼架構
雖然Detectron2還有很多部分沒有介紹到,但是源代碼分析到這應該對整體架構有了一定的理解了,具體的一些細節會在后續的文章中進行分析。現對Detectron2 整體代碼架構總結一下: