Detectron2源码阅读-整体流程(1)

最近有一个想法,想使用Mask DINO这个模型,查了一下,他是基于Detectron2框架实现的,但是又需求对这个框架进行一些魔改,所以需要对这个框架的源码进行学习。

首先是跟着Detectron2的官方文档进行学习:官方文档

直接跳过安装等环节,首先看dataset。

Dataset

我们在torch中使用Dataset的时候,一般这么写:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # 在这里对数据进行预处理、转换等操作
        return sample

# 创建数据集
data = [1, 2, 3, 4, 5]
dataset = CustomDataset(data)

# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

# 使用数据加载器遍历数据
for batch in dataloader:
    print(batch)

我们一般在Dataset方法中完成对数据的读取,并且使用__getitem__方法来根据索引取数据。

然而在Detectron2中,有所不同。

Detectron2中,首先对数据集进行注册,注册后,使用DatasetCatalog.get("my_dataset")方法来获得数据集字典注意,这时候仅仅是数据集字典,并没有读取数据集

数据集字典是一个由dict组合的List,List[Dict]

Detectron2 的标准数据集字典:

TaskFields
Commonfile_name, height, width, image_id
Instance detection/segmentationannotations
Semantic segmentationsem_seg_file_name
Panoptic segmentationpan_seg_file_name, segments_info

具体可见:https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html

每个字典,包含数据集中一张图像的所有信息

Dataloader

在detectron2中,提供了两个函数,build_detection_{train,test}_loader,然而,针对更高程度的自定义,我首先重写了这两个函数实现了我需要的新功能,并且自定义了mapper方法。

from detectron2.data import detection_utils as utils
 # Show how to implement a minimal mapper, similar to the default DatasetMapper
def mapper(dataset_dict):
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
    # can use other ways to read image
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    # See "Data Augmentation" tutorial for details usage
    auginput = T.AugInput(image)
    transform = T.Resize((800, 800))(auginput)
    image = torch.from_numpy(auginput.image.transpose(2, 0, 1))
    annos = [
        utils.transform_instance_annotations(annotation, [transform], image.shape[1:])
        for annotation in dataset_dict.pop("annotations")
    ]
    return {
       # create the format that the model expects
       "image": image,
       "instances": utils.annotations_to_instances(annos, image.shape[1:])
    }
dataloader = build_detection_train_loader(cfg, mapper=mapper)

其中,读取数据并加载到内存中,对注释进行decoder,数据增强方法,在这一步中实现。

Model and Training loop

Detectron2的搭建模型和训练逻辑 关系比较密切。

Model

困难的是,在使用torch编写自己的模型的时候,通常只要继承nn.Module,然后专注于forward方法的构建。输入一般为图像的tensor,输出预测结果的tensor。

然而,Detectron2只提供了使用cfg方法搭建模型的接口,并且输入和输出有所不同。有时候我就只想使用我自己写的模型。

Detectron2的model通过outputs = model(inputs)调用,其中,inputs是list[dict]。这里,每个dict就是经过mapper后的dataloader的输出。

模型在运行中,分为training和eval模式,其中,training模式的输出是loss构成的字典,而eval模式输出字典组成的列表,其中字典的内容就是模型的预测结果。

详细内容可见文档。

Training loop

训练的循环在 tools文件夹中的train_net.py文件中有介绍

训练的逻辑通过Trainer类来管理,提供了build_evaluator,build_train_loader,build_lr_scheduler三个方法。分别用来评估模型精度、构建训练逻辑、构建学习率调度器

def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()

在main方法中,实现模型的训练或推理过程

def invoke_main() -> None:
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

通过invok_main方法来实现模型的分布式训练。

然而,这里的Trainer类继承于DefaultTrainer类,我们实际使用的时候,更多的可能继承于SimpleTrainer。说到这,就需要深入研究一下detectron2中engine的部分。

engine由defaults.py hooks.py launch.py train_loop.py组成。

train_loop.py

train_loop.py由四个类组成,分别是HookBase, TrainBase, SimpleTrainer, AMPTrainer.

  • HookBase是一个基类,用于实现hook机制。提供了before_train, after_train, before_step, after_backward, after_step, state_dict的接口,但是没有具体实现。
  • TrainBase也是一个基类,定义的方法分为3类。(1)注册hook机制。(2)遍历hook list并执行。(3)实现train方法。
    def train(self, start_iter: int, max_iter: int):
        """
        Args:
            start_iter, max_iter (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
                # self.iter == max_iter can be used by `after_train` to
                # tell whether the training successfully finished or failed
                # due to exceptions.
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

从train方法的具体实现过程看,传入的参数分别是start_iter: int, max_iter: int。通过EventStorage(start_iter)来存储训练中需要记录的信息,通过run_step()方法实现训练过程。但是在TrainBase中,run_step()方法没有实现哦。

  • SimpleTrainer是TrainBase的子类
    它假定在每一步中,需要:(1)使用数据加载器中的数据计算损失。(2)使用上述损失计算梯度。(3)使用优化器更新模型。
    训练期间的所有其它任务(检查点、记录日志、评估、学习率调度)由Hook来管理。
    如果想要做比这更复杂的任务,可以要么继承 TrainerBase 并实现自己的 run_step,要么编写自己的训练循环。
    SimpleTrainer的init方法包括:
def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        zero_grad_before_forward=False,
        async_write_metrics=False,
    ):
  • AMPTrainer这里我们不做讲解。
  • 24
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值