刚开始学习使用 mmdetection 时,我们会遇到类似下面这样的代码;随着学习的深入,会发现这种写法越来越多。今天我们就以它为起点,深入研究一下。
# Registers the class in DATASETS; called without an explicit name, so the
# class is stored under its own class name. The class body is omitted in this
# excerpt.
@DATASETS.register_module()
class KittiTinyDataset(CustomDataset):
先定位 DATASETS 的声明位置:
# Registry instance named 'dataset'; dataset classes are registered into it.
DATASETS = Registry('dataset')
mmdetection 的核心组件之一是注册器机制 Registry。Registry 建立了一个从字符串到类的映射,例如从字符串 'KittiTinyDataset' 到 class KittiTinyDataset:它内部维护了一个 dict,其中 key 是字符串,value 是 class。考虑到扩展性,Registry 还借助 Python 的装饰器机制来方便地注册新的组件。
class Registry:
    """Registry holding a name-to-module mapping in ``_module_dict``.

    Args:
        name: Name of this registry.
        build_func: Function used to build an instance from the registry.
            When ``None``, ``parent.build_func`` is used if ``parent`` is
            given, otherwise ``build_from_cfg``.
        parent: Optional parent registry. When given, this registry is added
            to it through ``parent._add_children(self)``.
        scope: Scope of the registry. When ``None``, the value returned by
            ``self.infer_scope()`` is used.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func is chosen with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is not None:
            self.build_func = build_func
        elif parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = build_from_cfg

        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None
下面就是用来注册类的 register_module:既可以直接调用该函数来注册类(传入 module 参数),也可以把它当作装饰器使用。
def register_module(self, name=None, force=False, module=None):
    """Register a module class, either directly or as a decorator.

    Args:
        name: ``None``, a str, or a sequence of str to register the class
            under. Passed through to ``self._register_module`` as
            ``module_name``.
        force: Whether to override an existing entry with the same name.
            Must be a bool.
        module: Class to register. When given, it is registered immediately
            and returned; otherwise a decorator is returned.

    Returns:
        ``module`` when it is given, otherwise a decorator that registers the
        decorated class and returns it unchanged.

    Raises:
        TypeError: If ``force`` is not a bool or ``name`` has an unsupported
            type.
    """
    if not isinstance(force, bool):
        raise TypeError(f'force must be a boolean, but got {type(force)}')

    # NOTE: This is a workaround to be compatible with the old api,
    # while it may introduce unexpected bugs.
    if isinstance(name, type):
        return self.deprecated_register_module(name, force=force)

    # Raise the error ahead of time, before any decorator is applied.
    if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
        raise TypeError(
            'name must be either of None, an instance of str or a sequence'
            f' of str, but got {type(name)}')

    # Used as a normal method: x.register_module(module=SomeClass)
    if module is not None:
        self._register_module(
            module_class=module, module_name=name, force=force)
        return module

    # Used as a decorator: @x.register_module()
    def _register(cls):
        self._register_module(
            module_class=cls, module_name=name, force=force)
        return cls

    return _register
通过下面这个函数,把类名和类添加到 self._module_dict中
def _register_module(self, module_class, module_name=None, force=False):
    """Add ``module_class`` to ``self._module_dict`` under one or more names.

    Args:
        module_class: The class to register.
        module_name: ``None`` (use ``module_class.__name__``), a str, or an
            iterable of str.
        force: When True, existing entries with the same name are
            overwritten instead of raising.

    Raises:
        TypeError: If ``module_class`` is not a class.
        KeyError: If ``force`` is False and a name is already registered.
    """
    if not inspect.isclass(module_class):
        raise TypeError('module must be a class, '
                        f'but got {type(module_class)}')

    if module_name is None:
        module_name = module_class.__name__
    if isinstance(module_name, str):
        names = [module_name]
    else:
        # Materialize once so the names can be validated and then registered.
        names = list(module_name)

    # Validate every name before touching the dict, so a conflict on a later
    # name does not leave the earlier names half-registered.
    if not force:
        seen = set()
        for name in names:
            if name in self._module_dict or name in seen:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            seen.add(name)

    for name in names:
        self._module_dict[name] = module_class
注册不是我们的目的,我们的目的是使用。下面是我们的调用过程:
# Build the dataset described by cfg.data.train and wrap it in a list.
datasets = [build_dataset(cfg.data.train)]
build_dataset 先根据配置处理各种数据集包装器(如 ConcatDataset、RepeatDataset 等),其余情况交给 build_from_cfg 来构建:
def build_dataset(cfg, default_args=None):
    """Build a dataset, handling the dataset wrapper types explicitly.

    Args:
        cfg: A config dict, or a list/tuple of configs. A list/tuple is built
            element by element and wrapped in ``ConcatDataset``.
        default_args: Optional defaults forwarded to the nested
            ``build_dataset`` / ``build_from_cfg`` calls.

    Returns:
        The built dataset (possibly a wrapper around other datasets).
    """
    # Imported inside the function, as in the original code.
    from .dataset_wrappers import (ConcatDataset, RepeatDataset,
                                   ClassBalancedDataset, MultiImageMixDataset)

    if isinstance(cfg, (list, tuple)):
        return ConcatDataset([build_dataset(c, default_args) for c in cfg])

    # 'type' may legitimately be supplied only through default_args
    # (build_from_cfg supports that), so a missing key must not fail here.
    try:
        dataset_type = cfg['type']
    except KeyError:
        dataset_type = None

    if dataset_type == 'ConcatDataset':
        dataset = ConcatDataset(
            [build_dataset(c, default_args) for c in cfg['datasets']],
            cfg.get('separate_eval', True))
    elif dataset_type == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif dataset_type == 'ClassBalancedDataset':
        dataset = ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
    elif dataset_type == 'MultiImageMixDataset':
        # Work on a deep copy so the caller's cfg is not modified.
        cp_cfg = copy.deepcopy(cfg)
        # NOTE(review): unlike the other wrappers, default_args is not
        # forwarded to the inner dataset here - confirm this is intended.
        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
        cp_cfg.pop('type')
        dataset = MultiImageMixDataset(**cp_cfg)
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)
    return dataset
最终的构建过程是在 build_from_cfg 中完成的:return obj_cls(**args) 会根据 config 文件中的参数实例化已注册的类。
def build_from_cfg(cfg, registry, default_args=None):
    """Instantiate the class named by ``cfg['type']`` with the remaining keys.

    Args:
        cfg: Config dict. Must contain the key "type" unless ``default_args``
            provides it.
        registry: The ``Registry`` in which a string "type" is looked up.
        default_args: Optional dict of defaults; only keys missing from
            ``cfg`` are taken from it.

    Returns:
        The object created by ``obj_cls(**args)``.

    Raises:
        TypeError: If ``cfg``, ``registry``, ``default_args`` or the resolved
            "type" has an unsupported type.
        KeyError: If "type" is missing, or is a string that is not in the
            registry.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Copy so the caller's cfg is left untouched when 'type' is popped.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name, so prefix it.
        try:
            wrapped = type(e)(f'{obj_cls.__name__}: {e}')
        except Exception:
            # Some exception classes cannot be built from a single message;
            # do not let that mask the real error.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
自定义的数据集类继承自 CustomDataset,因此 obj_cls(**args) 实际上是通过 CustomDataset 的 __init__ 函数完成初始化的。
class KittiTinyDataset(CustomDataset):
    """Dataset subclass that adds nothing of its own to ``CustomDataset``."""
# Excerpt: CustomDataset is registered in DATASETS; only the constructor
# signature is shown here, the body is omitted.
# NOTE(review): file_client_args defaults to a dict that is created once at
# definition time and shared between calls - confirm it is never mutated.
@DATASETS.register_module()
class CustomDataset(Dataset):
def __init__(self,
ann_file,
pipeline,
classes=None,
data_root=None,
img_prefix='',
seg_prefix=None,
proposal_file=None,
test_mode=False,
filter_empty_gt=True,
file_client_args=dict(backend='disk')):
这样我们的数据集就生成成功了。