本系列为本菜鸡阅读mmdetection源代码的过程中所进行的一些查阅和思考,正文内容里将本菜鸡增添批注后的源代码奉上。如有不对之处,欢迎大家指正!
mmdetection的GitHub
mmcv的GitHub
builder.py
builder的定义文件mmdetection/mmdet/models/builder.py
from mmcv.utils import Registry, build_from_cfg
from torch import nn
# Registry是一个存放module的一个仓库。
# 以下部分就是建立backbone、neck、roi_extractor等七个仓库。
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
# build函数用于建立一个module,它的返回值是一个nn.Module。
# 整体思路就是将cfg(config)和它属于的仓库(registry)通过mmcv.utils.build_from_cfg建立module。
def build(cfg, registry, default_args=None):
"""Build a module.
Args:
cfg (dict, list[dict]): The config of modules, is is either a dict
or a list of configs.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list): # isinstance() 函数来判断一个对象是否是一个已知的类型,类似 type()。
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg # build_from_cfg()在后面进行讲解。
]
return nn.Sequential(*modules) # nn.Sequential(*modules)神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行。
else:
return build_from_cfg(cfg, registry, default_args)
# 通过调用build()函数建立backbone、neck、roi_extractor等七个modules。
def build_backbone(cfg):
"""Build backbone."""
return build(cfg, BACKBONES)
def build_neck(cfg):
"""Build neck."""
return build(cfg, NECKS)
def build_roi_extractor(cfg):
"""Build roi extractor."""
return build(cfg, ROI_EXTRACTORS)
def build_shared_head(cfg):
"""Build shared head."""
return build(cfg, SHARED_HEADS)
def build_head(cfg):
"""Build head."""
return build(cfg, HEADS)
def build_loss(cfg):
"""Build loss."""
return build(cfg, LOSSES)
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
mmcv.utils.build_from_cfg
函数定义位于mmcv/mmcv/utils/registry.py
中
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
# 判断输入的参数形式是否正确
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
raise KeyError(
f'the cfg dict must contain the key "type", but got {cfg}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
# 将obj_cls设置为config的type对应的class。
args = cfg.copy() # dict.copy()用来返回一个字典的浅复制。
obj_type = args.pop('type') # 从args中删除type和其对应的值,obj_type等于type对应的值。
if is_str(obj_type):
obj_cls = registry.get(obj_type) # registry.get(obj_type)是提取名为obj_type的class赋给obj_cls。
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type): # 检查是否为类
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value) # dict.setdefault():如果字典中包含有给定键,则返回该键对应的值,否则返回为该键设置的值。
return obj_cls(**args)