Freezing Model Parameters for Training in mmlab

Series contents

Chapter 2: Freezing the backbone network for fine-tuning



Preface

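In deep learning we sometimes need to fine-tune a model on top of an existing task, which requires freezing the network's backbone. While reading the documentation, I found that some mmlab backbones accept a parameter that freezes the backbone.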


I. Freezing the Backbone Network

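Take HRNet as an example (see the mmlab source link).
It supports a frozen_stages parameter. The HRNet backbone contains 4 stages, so I set this value to at most 3 (that is, 0-3). The exact upper bound depends on how the freezing logic is implemented in your mmseg version, so check the source before relying on it.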
class mmseg.models.backbones.HRNet(extra, in_channels=3, conv_cfg=None, norm_cfg={'requires_grad': True, 'type': 'BN'}, norm_eval=False, with_cp=False, frozen_stages=-1, zero_init_residual=False, multiscale_output=True, pretrained=None, init_cfg=None)

1. Main parameters

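extra (dict) – Detailed configuration for each stage of HRNet. There must be 4 stages, and the configuration for each stage must have 5 keys: num_modules, num_branches, block, num_blocks, num_channels (the same keys used in the code below).

in_channels (int) – Number of input image channels. Normally 3.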

conv_cfg (dict) – Dictionary to construct and config conv layer. Default: None.

norm_cfg (dict) – Dictionary to construct and config norm layer. Use BN by default.

norm_eval (bool) – Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False.

with_cp (bool) – Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False.

frozen_stages (int) – Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Default: -1.

zero_init_residual (bool) – Whether to use zero init for last norm layer in resblocks to let them behave as identity. Default: False.

multiscale_output (bool) – Whether to output multi-level features produced by multiple branches. If False, only the first level feature will be output. Default: True.

pretrained (str, optional) – Model pretrained path. Default: None.

init_cfg (dict or list[dict], optional) – Initialization config dict. Default: None.
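To make the wording "stop grad and set eval mode" more concrete, here is a minimal sketch in plain PyTorch. It is not the mmseg implementation: ToyBackbone, its two stages, and the _freeze_stages helper are hypothetical names invented for illustration, and the index convention (0 freezes the stem, each further step freezes one more stage) is an assumption of this toy. It only shows the general pattern: switch the frozen modules to eval mode, set requires_grad=False on their parameters, and re-apply both after every call to train().

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical 2-stage backbone that mimics a `frozen_stages` option."""

    def __init__(self, frozen_stages=-1, norm_eval=False):
        super().__init__()
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.stem = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
        self.stage1 = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8))
        self.stage2 = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8))
        self._freeze_stages()

    def _freeze_stages(self):
        # Assumed convention for this toy: 0 freezes the stem,
        # 1 additionally freezes stage1, 2 additionally freezes stage2.
        frozen = []
        if self.frozen_stages >= 0:
            frozen.append(self.stem)
        for i in range(1, self.frozen_stages + 1):
            frozen.append(getattr(self, f'stage{i}'))
        for module in frozen:
            module.eval()  # BN running statistics are no longer updated
            for param in module.parameters():
                param.requires_grad = False  # no gradients for these weights

    def train(self, mode=True):
        # nn.Module.train() puts every submodule back into train mode,
        # so the frozen part has to be frozen again after each call.
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            # norm_eval keeps all norm layers in eval mode, frozen or not.
            for m in self.modules():
                if isinstance(m, nn.modules.batchnorm._BatchNorm):
                    m.eval()
        return self

    def forward(self, x):
        return self.stage2(self.stage1(self.stem(x)))


model = ToyBackbone(frozen_stages=1)
model.train()
out = model(torch.rand(2, 3, 16, 16))
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)                                     # only stage2 parameters remain
print(model.stage1.training, model.stage2.training)  # frozen stage stays in eval mode
```

The override of train() is the important design point in this sketch: without it, the first call to model.train() in a training loop would silently put the frozen BatchNorm layers back into training mode.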

The code is as follows:

from mmseg.models import HRNet
import torch
extra = dict(
    stage1=dict(
        num_modules=1,
        num_branches=1,
        block='BOTTLENECK',
        num_blocks=(4, ),
        num_channels=(64, )),
    stage2=dict(
        num_modules=1,
        num_branches=2,
        block='BASIC',
        num_blocks=(4, 4),
        num_channels=(32, 64)),
    stage3=dict(
        num_modules=4,
        num_branches=3,
        block='BASIC',
        num_blocks=(4, 4, 4),
        num_channels=(32, 64, 128)),
    stage4=dict(
        num_modules=3,
        num_branches=4,
        block='BASIC',
        num_blocks=(4, 4, 4, 4),
        num_channels=(32, 64, 128, 256)))
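# Build the backbone. frozen_stages keeps its default of -1 here, so nothing is frozen.
model = HRNet(extra, in_channels=1)
model.eval()
inputs = torch.rand(1, 1, 32, 32)
# Run a forward pass without tracking gradients and print the shape of each output level.
with torch.no_grad():
    level_outputs = model(inputs)
for level_out in level_outputs:
    print(tuple(level_out.shape))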
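The example above builds HRNet with the default frozen_stages=-1. To actually freeze part of the backbone, the parameter has to be passed explicitly, either to the constructor (per the signature above) or through the backbone dict of a config file. The fragment below is a sketch of the config-file form. The values frozen_stages=2 and norm_eval=True are illustrative choices rather than values from this post, and the other keys a real model config needs (decode head, losses, and so on) are left out.

```python
# Hypothetical config fragment: only the backbone entry is shown.
# A complete mmseg model config needs further keys that are omitted here.
model = dict(
    backbone=dict(
        type='HRNet',
        frozen_stages=2,   # illustrative value, not taken from the post
        norm_eval=True,    # illustrative: keep norm layers in eval mode as well
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(32, 64, 128, 256)))))

print(model['backbone']['frozen_stages'])
```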

Summary

As far as I can tell, the networks currently support this kind of freezing only for the backbone.
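If another part of a model needs to be frozen, one generic fallback in plain PyTorch is to switch off requires_grad by parameter name. The sketch below is only an illustration and not an mmlab feature: freeze_by_prefix is a hypothetical helper, and the submodule names backbone, neck, and head belong to a stand-in model built for this example.

```python
from torch import nn


def freeze_by_prefix(model, prefixes):
    """Set requires_grad=False on every parameter whose name starts with a prefix.

    Hypothetical helper for illustration; returns the names that were frozen.
    """
    frozen = []
    for name, param in model.named_parameters():
        if name.startswith(tuple(prefixes)):
            param.requires_grad = False
            frozen.append(name)
    return frozen


# Stand-in model: `backbone`, `neck`, and `head` are hypothetical submodule names.
toy = nn.ModuleDict(dict(
    backbone=nn.Conv2d(3, 8, 3),
    neck=nn.Conv2d(8, 8, 1),
    head=nn.Conv2d(8, 2, 1),
))

frozen = freeze_by_prefix(toy, ['backbone.', 'neck.'])
trainable = [name for name, p in toy.named_parameters() if p.requires_grad]
print(frozen)
print(trainable)
```

Note that this only stops gradients. Normalization layers inside a frozen submodule keep updating their running statistics unless they are also put into eval mode.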
