Detectron2 代码解读 (1) 如何构造一个模型

Detectron2解读全部文章链接:

  1. Facebook计算机视觉开源框架Detectron2学习笔记 — 从demo到训练自己的模型

  2. Detectron2 “快速开始” Detection Tutorial Colab Notebook 详细解读

  3. Detectron2 官方文档详细解读 (上)

  4. Detectron2 官方文档详细解读(下)

  5. Detectron2 代码解读(1)如何构建模型

  6. Pytorch 基于 Detectron2 从零实现 Unet

Detectron2 代码解读(1)如何构造一个模型

读完官方文档之后对 Detectron2 已经有了基本了解。这个框架各个组件定义的非常完善,从创建模型到训练模型再到测试模型,每一步官方都提供了抽象,基本流程是这样的:

  1. 准备数据集 – 注册COCO格式数据集或者使用自定义结构数据集,注册 DatasetCatalog 和 MetadataCatalog,告诉模型如何提取你的数据。

  2. 数据集加载进入 Dataloader – 可以使用 build_detection_train_loader 和 build_detection_test_loader 快速创建dataloader,这个dataloader包含一个mapper,负责将输入的图片格式经过数据增强变为模型可以直接拿去 forward 的格式。可以使用 DefaultMapper或者自己实现 mapper。

  3. 创建模型,配置好 config 后,build_model(cfg) 直接根据 config 中定义的各个组件拼接模型。各个组件 Detectron2 也提供了 Registry 机制,你可以直接注册自己写的组件,之后在 config 中快速调用。

  4. 训练模型,可以直接使用 DefaultTrainer,继承自 SimpleTrainer,包含了训练常用的基础参数和操作。如果不能满足需求直接继承 DefaultTrainer 重写方法即可。

  5. 评测模型,可以直接使用 DatasetEvaluator 对模型性能进行评测。如果使用自定义数据集,可以直接继承 DatasetEvaluator 重写方法。

  6. 单张推理,可以直接使用 DefaultPredictor,如果不能满足需求,也可以继承重写方法。

基本的流程已经出来了,下面进入代码分析环节。作为刚入门的新人而说,一般不太需要一些不常见的 Dataloader 操作或者一些新任务,我们直接从创建模型这里开始分析。

直接从 tools/train_net.py 这里开始。

我们发现文件首先运行了 main(args)

在 main 函数里面,首先调用了一个 cfg = setup(args),之后创建了一个 Trainer 对象,然后执行了 trainer.train()。主要关键在于执行训练的这个方法 trainer.train(),我们下面具体分析 Trainer 类,它继承了 DefaultTrainer 这个类。

class Trainer(DefaultTrainer):

	@classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        # 创建 evaluator,返回一个 DatasetEvaluator
    
    @classmethod
    def test_with_TTA(cls, cfg, model):
        # 训练结束之后进行 Evaluate,只支持一些 R-CNN 模型。

Trainer 类里面没有训练的函数,那我们往上看,去看 DefaultTrainer 类,在 detectron2/engine/defaults.py 里面:

class DefaultTrainer(TrainerBase):
"""
这是一个训练器,它按照如下步骤执行:
1. 用给定的 模型,优化器,和 dataloader 创建一个 SimpleTrainer 对象,并且制定一个学习率schedule.
2. 从 cfg.MODEL.WEIGHTS 加载模型,如果 resume_or_load 被调用。
3. 注册一些 hooks
为了简化,DefaultTrainer 默认了很多流程,如果基于研究需要你需要一些更复杂的操作,你可以 1. 覆写 DefaultTrainer 类
2. 使用 SimpleTrainer 类,并写规划你自己的训练流程
3. 完全重新写一个 Trainer
你应该这么使用:
	trainer = DefaultTrainer(cfg)
	trainer.resume_or_load()
	trainer.train()
"""
	def __init__(self, cfg):
        # 部分不是很关键的代码我们直接跳过
        super().__init__()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        model = self.build_model(cfg) # 重点,根据 cfg 创建模型
        optimizer = self.build_optimizer(cfg, model) # 根据 cfg 创建优化器
        data_loader = self.build_train_loader (cfg) # 根据 cfg 创建 dataloader
        model = create_ddp_model(model, broadcast_buffers=False)
        # 这边是LR schedule 和 Checkpointer
        # 用来规划学习率,以及保存模型
        self.scheduler = self.build_lr_scheduler (cfg, optimizer)
        self.checkopinter = DetectionCheckpointer(
        	model, cfg.OUTPUT_DIR, trainer=weakref.proxy(self))
        # iteration 从0开始,从 MAX_ITER 结束
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg
        self.register_hooks(self.build_hooks())
    
    def resume_or_load(self, resume=True):
        # 从 cfg.OUTPUT_DIR 中读取最后保存的模型,读取状态意味着所有 LR
        # 优化器,iter 等状态都会被读取。
    def build_hooks(self):
        # 创建一个 hook 的列表,hook 我们之后单独开文章研究
    def build_writers(self):
        # 创建一个 writer 的列表,用于在 OUTPUT_DIR 保存训练信息
    
    def train(self):
        # 执行训练,继承了 SimpleTrainer 的 train() 方法
        super().train(self.start_iter, self.max_iter)
    
    @classmethod
    def build_model(cls, cfg):
        # 重点来了,这里返回一个 torch.nn.Module 的对象
        # 这里调用的 detectron2/modeling 的 build_model 方法
        model = build_model(cfg)
        return model
    # 这些方法都是类方法,看名字即知道作用,在这个文件里这些方法是没有定义的
    # 这里留出这些方法名的作用是,如果你需要自定义,你可以覆写这些方法
    def build_optimizer(cls, cfg, model)
    def build_lr_scheduler(cls, cfg, optimizer)
    def build_train_loader(cls, cfg, dataset_name)
    def build_test_loader(cls, cfg, dataset_name)
    def build_evaluator(cls, cfg, dataset_name)
    def test(cls, cfg, model, evaluators=None)
    # 如果训练途中 num_workers 发生变化,这里会自动调整 batch size
    def auto_scale_workers(cfg, num_workers: int)

到这里,我们发现创建模型使用的是一个 build_model 的方法,被定义在 detectron2/modeling 中。

进入 detectron2/modeling/meta_arch/build.py,我们发现了这个 build_model 方法:

def build_model(cfg):
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
    model.to(torch.device(cfg.MODEL.DEVICE))
    _log_api_usage("modeling.meta_arch." + meta_arch)
    return model

这里我们简单介绍一下 Detectron2 的 Registry 机制

Registry 的具体实现我们就略过了,简单来说,这个机制可以把字符串映射到对应的方法或类,简单来说:

registry = Registry('funcs')

# 把 custom_print(str) 方法注册到 'funcs' 中。
@registry.register()
def custom_print(str):
    print(str)
# 之后通过 'custom_print' 这个字符串即可对应到这个方法
registry.get('custom_print')('hello')

因此,这里的 META_ARCH_REGISTRY.get(meta_arch)(cfg) 也是同理,这里会根据 cfg.MODEL.META_ARCHITECTURE 创建一个 meta_architecture。这个 meta_architecture 包含了各个组件,每个组建又在 cfg 中被定义,举个例子,我们去看看 configs/Base-RCNN-FPN.yaml 这个 config 文件是怎么写的。一些暂时无关紧要的条目我们先略过,我们看到:

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"
  RESNETS:
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
  ANCHOR_GENERATOR:
    SIZES: [[32], [64], [128], [256], [512]]
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]
  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000
    PRE_NMS_TOPK_TEST: 1000
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14

OK,我们首先看到这个配置文件中,META_ARCHITECTURE 是 “GeneralizedRCNN”,我们随后找到 detectron2/modeling/meta_arch/rcnn.py 中 GeneralizedRCNN 这个类。

@META_ARCH_REGISTRY.register()
class GeneralizedRCNN(nn.Module):
    
    @configurable
    def __init__(self, *,
                 # 重点,这里是各个组件
                 backbone: Backbone,
                 proposal_generator: nn.Module,
                 roi_heads: nn.Module,
                 # 数据归一化用到的一些参数,暂不介绍
                 pixel_mean: Tuple[float],
                 pixel_std: Tuple[float],
                 # 图像输入格式,可视化的周期,暂不介绍
                 input_format: Optional[str] = None,
                 vis_period: int = 0):
        """
        参数: 
        	backbone:必须是一个 Backbone 对象,我们随后介绍
        	proposal_generator:使用 backbone 提取出的特征图创建 proposal
        	roi_heads:逐 ROI 计算
        	...
        """
        super().__init__()
        self.backbone = backbone
        self.proposal_generator = proposal_generator
        self.roi_heads = roi_heads
        # 其他成员略过
        # 到这里我们发现,GeneralizedRCNN 这个类的每一个组件都是拆开定义的
    
    # 从 config 创建
    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        return {
            "backbone": backbone,
            "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()),
            "roi_heads": build_roi_heads(cfg, backbone.output_shape()),
            "input_format": cfg.INPUT.FORMAT,
            "vis_period": cfg.VIS_PERIOD,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }
    @property
    def device(self): return self.pixel_mean.device
    # 训练过程中可视化 proposals
    def visualize_training(self, batched_inputs, proposals)
    
    def forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]]):
        # 网络的完整 forward 流程,输入 batched 图片,输出 list[dict] 包含了所有结果。
        # 如果不是训练,返回单张推理的结果
        if not self.training:
            return self.inference(batched_inputs)
        images = self.preprocess_image(batched_inputs) # 预处理
        # 如果有 ground truth 的话,保存到 gt_instances
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None
        # 从 backbone 获取特征图
        features = self.backbone(images.tensor)
        # 如果有 proposal_generator
        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            # 如果没有的话,必须要预先给定 proposals,同时 proposal loss 为空
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}
        
		# 送入 ROI_heads
        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        
        if self.vis_period > 0:
            #...
        
        # 把不同部分的 loss 放入 losses,返回
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
            
    def inference(self,
                 batched_inputs: Tuple[Dict[str, torch.Tensor]],
                 detected_instances: Optional[List[Instances]]=None,
                 do_postprocess: bool = True,
               	 )
   		# 测试时的推理,不求 Loss
        # Args: batched_inputs (list[dict]) 输入图片
        #	detected_instances (list[Instances] or None),如果图片中已存在物体,推理时会避免重复
        #   do_postprocess (bool) 是否进行后处理
        # Returns:
        #	如果 do_postproces 为 True,返回和 forward 一样
        #   否则返回 list[Instances] 检测到的物体
        assert not self.training
        
        images = self.preprocess_image(batched_inputs) # 预处理
        features = self.backbone(images.tensor) # 送入 backbone 得到特征图
        
        if detected_instances is None:
            if self.proposal_generator is not None:
                proposals, _ = self.proposal_generator(images, features, None)
                # 如果有 proposal_generator,得到 proposals
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            	# 否则必须提前指定 proposals,不然报错
                
            # 送入 ROI heads 得到输出结果
            results, _ = self.roi_heads(images, features, proposals, None)
        else:
            #...
        
        if do_postprocess:
            #...
        return results
    
    def preprocess_image(self, batched_inputs: Tuple[Dict[str, torch.Tensor]])
    @staticmethod
    def _postprocess(instances, batched_inputs: Tuple[Dict[str, torch.Tensor]], image_sizes):
        # 把检测到的物体缩放回图片大小   
        

到这里整个流程就清晰了,build_model 函数创建一个 meta_architecture,这个是一个大体的框架,之后这个 meta_architecture 根据 config 填入各个组件,构造整个模型。我们之后再来研究一下 backbone。

再回到这个 GeneralizedRCNN 类,我们发现 backbone 的构建是通过 backbone = build_backbone(cfg) 函数来构建的,我们在 detectron2/modeling/backbone/build.py 中找到这个函数。其它都忽略,我们发现 backbone 还是通过 BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape) 来构建的,函数返回了一个 Backbone 对象,Backbone 类继承自 nn.Module,我们来研究一下这个 Backbone 类。

打开 detectron2/modeling/backbone/backbone.py,我们发现了这个类,这个类仅仅创建了一个抽象,其中的函数完全没有定义,结构如下:

class Backbone(nn.Module, metaclass=ABCMeta):
    
    def __init__(self):
        super().__init__()
    
    @abstractmethod
    def forward(self): pass
    @property
    def size_divisibility(self) -> int: return 0 # 某些要求输入图片必须被某尺寸整除
    def output_shape(self): # 输出的特征图的shape
        return {
            name: ShapeSpec(channels = self._out_feature_channels[name], 
                            stride=self._out_feature_strides[name])
            for name in self._out_features
        }

这里也就是说 init 和 forward 都需要我们自己去重写。比如 ResNet (没有FPN)这种backbone,去看一下 detectron2/modeling/resnet.py,我们发现了 ResNet 这个类,它继承了 Backbone 类并且重写了 init 和 forward 方法。具体代码可以自己查阅,这里没有必要放出来。

随后我们发现了 build_resnet_backbone (cfg, input_shape) 这个函数,上面有 @BACKBONE_REGISTRY.register() 修饰符,也就是说我们上面的 BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape) 是调用的这个函数,随后它返回了一个 ResNet 对象。

之后我们去看 proposal generator,它在 meta architecture 中由 build_proposal_generator 函数创建,这个函数同样使用了 Registry 中的名字,目前 proposal generator, config 里没有具体写出,查看 detectron2/config/defaults.py 中的所有默认 config,我们发现了 MODEL.PROPOSAL_GENERATOR.NAME = “RPN”。注意,所有在 config 中没有明确标注的条目都使用这个文件中的默认内容。

具体可以查看 detectron2/modeling/proposal_generator/rpn.py 中的 RPN 类,这里不多介绍了,之后会开文详细介绍。

ROI Heads也是同理,想知道这些部件具体达成了什么作用,请参考 Faster RCNN 的论文,这是很经典的结构了。

  • 20
    点赞
  • 55
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
由于Detectron2一个非常灵活的深度学习框架,其训练代码可以因特定问题而异。但是,以下是Detectron2训练代码的基本模板: ```python from detectron2.config import get_cfg from detectron2.data.datasets import register_coco_instances from detectron2.engine import DefaultTrainer from detectron2.utils.logger import setup_logger # Configurations cfg = get_cfg() cfg.merge_from_file("configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml") cfg.DATASETS.TRAIN = ("train_dataset_name",) cfg.DATASETS.TEST = () cfg.DATALOADER.NUM_WORKERS = 2 cfg.SOLVER.IMS_PER_BATCH = 2 cfg.SOLVER.BASE_LR = 0.00025 cfg.SOLVER.MAX_ITER = 3000 cfg.SOLVER.STEPS = (1000, 2000) cfg.SOLVER.GAMMA = 0.5 cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 # Dataset registration register_coco_instances("train_dataset_name", {}, "path/to/train.json", "path/to/train/images") register_coco_instances("test_dataset_name", {}, "path/to/test.json", "path/to/test/images") # Logger setup setup_logger() # Training trainer = DefaultTrainer(cfg) trainer.resume_or_load(resume=False) trainer.train() ``` 在上面的代码中,需要注意以下几点: 1. `cfg.merge_from_file("configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")`指定了模型配置文件的位置和名称。需要指定与数据集相对应的模型配置文件。 2. `cfg.DATASETS.TRAIN`参数需要指定训练数据集的名称,可以与注册数据集时指定的名称相同,也可以不同。 3. `register_coco_instances()`函数用于将COCO格式的数据集注册到Detectron2中,需要指定数据集的名称、COCO格式的标注文件位置以及图像数据所在的文件夹路径。 4. 训练器(`trainer`)定义和启动后,可以使用`trainer.train()`方法运行训练。 以上代码仅供参考,具体的训练代码需要根据问题和数据集进行调整和修改。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值