mmlab之调用mmpretrain预训练模型

系列文章目录



前言

MMLab的MMPretrain中集成了主流的自监督、对比学习模型,通过对现有数据进行预训练,来提升下游任务的性能。这里对下游任务如何调用已经训练好的预训练权重模型进行介绍,原文链接可参考帮助文档


一、示例1

使用在 MMPretrain 中实现的骨干网络
假设想将 MobileNetV3-small 作为 RetinaNet 的骨干网络,则配置文件如下

代码如下:

_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# please install mmpretrain
# import mmpretrain.models to trigger register_module in mmpretrain
custom_imports = dict(imports=['mmpretrain.models'], allow_failed_imports=False)
pretrained = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth'
model = dict(
    backbone=dict(
        _delete_=True, # 将 _base_ 中关于 backbone 的字段删除
        type='mmpretrain.MobileNetV3', # 使用 mmpretrain 中的 MobileNetV3
        arch='small',
        out_indices=(3, 8, 11), # 修改 out_indices
        init_cfg=dict(
            type='Pretrained',
            checkpoint=pretrained,
            prefix='backbone.')), # mmpretrain 中骨干网络的预训练权重含义 prefix='backbone.',为了正常加载权重,需要把这个 prefix 去掉。
    # 修改 in_channels
    neck=dict(in_channels=[24, 48, 96], start_level=0))

二、示例2

通过 MMPretrain 使用 TIMM 中实现的骨干网络
由于 MMPretrain 提供了 PyTorch Image Models (timm) 骨干网络的封装,用户也可以通过 MMPretrain 直接使用 timm 中的骨干网络。假设想将 作为 RetinaNet 的骨干网络,则配置文件如下。

代码如下:

# https://github.com/open-mmlab/mmdetection/blob/main/configs/timm_example/retinanet_timm_efficientnet_b1_fpn_1x_coco.py
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

# please install mmpretrain
# import mmpretrain.models to trigger register_module in mmpretrain
custom_imports = dict(imports=['mmpretrain.models'], allow_failed_imports=False)
model = dict(
    backbone=dict(
        _delete_=True, # 将 _base_ 中关于 backbone 的字段删除
        type='mmpretrain.TIMMBackbone', # 使用 mmpretrain 中 timm 骨干网络
        model_name='efficientnet_b1',
        features_only=True,
        pretrained=True,
        out_indices=(1, 2, 3, 4)), # 修改 out_indices
    neck=dict(in_channels=[24, 40, 112, 320])) # 修改 in_channels

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

type=‘mmpretrain.TIMMBackbone’ 表示在 MMDetection 中使用 MMPretrain 中的 TIMMBackbone 类,并且使用的模型为 EfficientNet-B1,其中 mmpretrain 表示 MMPretrain 库,而 TIMMBackbone 表示 MMPretrain 中实现的 TIMMBackbone 包装器。

三、示例3

如果你在配置文件中已经冻结了骨干网络并希望在几个训练周期后解冻它,你可以通过 hook 来实现这个功能。以用 ResNet 为骨干网络的 Faster R-CNN 为例,你可以冻结一个骨干网络的一个层并在配置文件中添加如下 custom_hooks:
代码如下(示例):

_base_ = [
    '../_base_/models/faster-rcnn_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    # freeze one stage of the backbone network.
    backbone=dict(frozen_stages=1),
)
custom_hooks = [dict(type="UnfreezeBackboneEpochBasedHook", unfreeze_epoch=1)]

同时在 mmdet/core/hook/unfreeze_backbone_epoch_based_hook.py 当中书写 UnfreezeBackboneEpochBasedHook 类

from mmengine.model import is_model_wrapper
from mmengine.hooks import Hook
from mmdet.registry import HOOKS


@HOOKS.register_module()
class UnfreezeBackboneEpochBasedHook(Hook):
    """Unfreeze backbone network Hook.

    Args:
        unfreeze_epoch (int): The epoch unfreezing the backbone network.
    """

    def __init__(self, unfreeze_epoch=1):
        self.unfreeze_epoch = unfreeze_epoch

    def before_train_epoch(self, runner):
        # Unfreeze the backbone network.
        # Only valid for resnet.
        if runner.epoch == self.unfreeze_epoch:
            model = runner.model
            if is_module_wrapper(model):
                model = model.module
            backbone = model.backbone
            if backbone.frozen_stages >= 0:
                if backbone.deep_stem:
                    backbone.stem.train()
                    for param in backbone.stem.parameters():
                        param.requires_grad = True
                else:
                    backbone.norm1.train()
                    for m in [backbone.conv1, backbone.norm1]:
                        for param in m.parameters():
                            param.requires_grad = True

            for i in range(1, backbone.frozen_stages + 1):
                m = getattr(backbone, f'layer{i}')
                m.train()
                for param in m.parameters():
                    param.requires_grad = True

总结

官方帮助文档已经写得很清楚了,帮助文档内容有点多,熟悉一下帮助文档的整体结构,关键的部分看的时候尽量把整个帮助文档都扫一遍吧,有个印象更好找一些,具体函数api这些用的时候再找就好。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云朵不吃雨

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值