模型的创建
接上面一篇,由于在一篇里面讲不完,所以把具体的模型的创建过程单独拿出来讲了。这里以FasterRCNN为例。知道一个就懂其他的模型创建过程了。
假如前面的你已经知道了,现在进入到了创建模型的函数:
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
通过registry类得到注册器,然后从注册器中根据key取出相应的类,得到该类后可以正常创建创建实例;
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict): # cfg是否是个字典
raise TypeError(f'cfg must be a dict, but got {
type(cfg)}')
if 'type' not in cfg: # cfg中必须有type字段
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {
cfg}\n{
default_args}')
if not isinstance(registry, Registry): #传进来的register必须是个Register类型的注册器实例
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {
type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {
type(default_args)}')
args = cfg.copy()
if default_args is not None: #{'test_cfg': None, 'train_cfg': None}
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type') # 得到type的类型,这里为CocoDataset
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type) #通过register的get方法得到类 ,可以看到返回的是个class
if obj_cls is None:
raise KeyError(
f'{
obj_type} is not in the {
registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {
type(obj_type)}')
try:
return obj_cls(**args) # 主要是进这里追溯源码,会进入到这个函数
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{
obj_cls.__name__}: {
e}')
mmdet.models.detectors.faster_rcnn.FasterRCNN,可以看到是继承了两阶段的模型,只需要传进去 backbone,head,neck等参数,就返回这个类的对象了。所以需要进去看看父类是如何构建的。如果需要添加父类以外的其他功能,可以在这个类里面写。
@DETECTORS.register_module()
class FasterRCNN(TwoStageDetector):
"""Implementation of `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_"""
def __init__(self,
backbone,
rpn_head,
roi_head,
train_cfg,
test_cfg,
neck=None,
pretrained=None,
init_cfg=None):
super(FasterRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
TwoStageDetector有继承了BaseDetector。