detectron2源码阅读2---使用configurable装饰器来构建dataloader

前言

 本篇主要讲解detectron2是如何读取数据集并用dataloader进行包装的。一个目标检测模型往往包含众多参数,那么如何提取出对应数据集的参数呢?detectron2设计了configuable装饰器。因此,本文主要分析下读取过程。细节后续有空在写。

1、从train.py文件debug开始

  在介绍detectron2的engine中,默认的训练器是engine/defaults.py文件中的 类class DefaultTrainer(TrainerBase)。在其中初始化类中,有一个构建读取数据集的接口:

data_loader = self.build_train_loader(cfg)

而build_train_loader继续下挖一层:

    @classmethod
    def build_train_loader(cls, cfg):
        return build_detection_train_loader(cfg)

OK,继续挖…

@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):

okay,终于遇到今天的难点了,我们发现,该函数用@configurable装饰器进行了装饰,而且该装饰器后边还跟着一个from_config参数。现在不理解装饰器没关系,我们现在只需要知道被装饰器包装的函数: build_detection_train_loader记为orig_func. 此处易于后续理解,先记住!
装饰器的执行顺序就是先执行装饰的部分,然后在执行被装饰的函数orig_func。所以,继续debug,你会进入到configurable装饰器里。而且,记住configurable的第一个参数是"from_config = _train_loader_from_config"。这里,我可以先告诉你:_train_loader_from_config是一个函数,你现在不需要知道其具体内容,你现在只需要把它看成from_config 即可。

2、函数装饰器configurable

  configurable实现在detectron2/config/config.py文件中。我这里先贴下其部分源码:

def configurable(init_func=None, *, from_config=None):      # * 后面参数必须明示写出来

    if init_func is not None:          # 若指定了init_func则执行if条件语句
        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)
        return wrapped

    else:                             # 若没有指定,则执行else语句
        def wrapper(orig_func):       # 此时orig_func就是指被装饰的原始函数

            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)           
                else:
                    return orig_func(*args, **kwargs)
            return wrapped
        return wrapper

  此处结构是if -else, 区别就是指定了init_func参数。第一部分我们只指定了from_config参数,因此,我们只需要看else部分即可。此时你发现了代码结构:
def warpper(orig_func)
_get_args_from_config(from_config)
return wrapper
  其本质就是funA = funB(funA)。简单来说就是:此处funA是orig_func,之后在funB函数wrapped函数内内拓展了funA的一部分功能,比如该函数内部中间开小灶,调用了一个_get_args_from_config(from_config)函数并且返回了一个orig_func(**args),即返回funA。
  所以,到此你就可以猜出来:构建dataloader首先开小灶调用from_config函数构建了一个dataset类,之后在通过orig_func构建dataloader等。

3、合并

3.1 from_config函数

  第二部分你了解了装饰器内容,结合第一部分你记住的:orig_func 是build_detection_train_loade, from_config是_train_loader_from_config。这里看下_train_loader_from_config的代码:

def _train_loader_from_config(cfg, *, mapper=None, dataset=None, sampler=None):
    if dataset is None:
        dataset = get_detection_dataset_dicts(                  # 读取数据集
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset))
        elif sampler_name == "RepeatFactorTrainingSampler":
            repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                dataset, cfg.DATALOADER.REPEAT_THRESHOLD
            )
            sampler = RepeatFactorTrainingSampler(repeat_factors)
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }

  该函数主要作用就是构建了dataset和sampler。即开小灶程序拓展的功能。

3.2 总的程序流程

 现在走一下调用流程,首先开小灶,程序执行了explicit_args = _get_args_from_config(from_config, *args, **kwargs)。我这里贴下代码:

def _get_args_from_config(from_config_func):      # !!!!!!!!!!!此处第一个参数即from_config
    if support_var_arg:  # forward all arguments to from_config, if from_config accepts them
        ret = from_config_func(*args, **kwargs)
    else:
        # forward supported arguments to from_config
        supported_arg_names = set(signature.parameters.keys())
        extra_kwargs = {}
        for name in list(kwargs.keys()):
            if name not in supported_arg_names:
                extra_kwargs[name] = kwargs.pop(name)
        ret = from_config_func(*args, **kwargs)            # !!!!!!!!!!!!!!!!!!!
        # forward the other arguments to __init__
        ret.update(extra_kwargs)
    return ret

  主要看代码中我加感叹号的部分,实质上该函数就是调用了3.1节中的函数,即开小灶去构建了一个dataset和sampler;之后在调用orig_func完成dataloader构建:

@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):

    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )

总结

  detectron2中用到configurable装饰器的地方不少。就不一一列举了,后续会介绍如何封装dataset,即如何开小灶的。

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Detectron2中,要debug进入DataLoader内部,可以根据默认的训练器(DefaultTrainer)中的代码进行操作。在engine/defaults.py文件中,可以找到DefaultTrainer类的定义。其中,在初始化类的部分,可以看到以下代码:data_loader = self.build_train_loader(cfg) [1。 要debug进入DataLoader内部,可以在这一行代码之后添加断点,并运行代码。当代码执行到这个断点时,可以使用调试器进一步查看和调试DataLoader的内部实现。 另外,你也可以使用类似下面的代码在训练过程中查看DataLoader中的数据: ```python for x, y in train_loader: print(x, y) break ``` 这段代码会迭代训练集的第一个批次数据,并打印出来。这样你就可以看到DataLoader返回的数据的样子了 [2。 需要注意的是,以上操作针对的是Detectron2的默认训练器和数据加载方式。如果你在自定义训练器或者数据加载逻辑中,可能需要根据具体情况进行相应的调试操作 [3。123 #### 引用[.reference_title] - *1* *3* [detectron2源码阅读2---使用configurable装饰器构建dataloader](https://blog.csdn.net/wulele2/article/details/119081975)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}} ] [.reference_item] - *2* [系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)](https://blog.csdn.net/wuzhongqiang/article/details/105499476)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值