Series Table of Contents
Chapter 2: Freezing the Backbone Network for Fine-Tuning
Preface
In deep learning tasks, we sometimes need to fine-tune a model on top of an existing task, which requires freezing the network's backbone. While reading the documentation, I found that some mmlab backbone networks support freezing the backbone through a parameter.
I. Freezing the Backbone Network
Taking HRNet as an example: mmlab source code link
It supports the frozen_stages parameter. The HRNet backbone contains 4 stages, so the maximum value for this parameter is 3 (valid stage indices are 0-3).
class mmseg.models.backbones.HRNet(extra, in_channels=3, conv_cfg=None, norm_cfg={'requires_grad': True, 'type': 'BN'}, norm_eval=False, with_cp=False, frozen_stages=-1, zero_init_residual=False, multiscale_output=True, pretrained=None, init_cfg=None)
1. Main Parameters
extra (dict) – Detailed configuration for each stage of HRNet. There must be 4 stages, and the configuration for each stage must have 5 keys: num_modules, num_branches, block, num_blocks, and num_channels (as shown in the extra dict in the code below).
in_channels (int) – Number of input image channels. Normally 3.
conv_cfg (dict) – Dictionary to construct and config conv layer. Default: None.
norm_cfg (dict) – Dictionary to construct and config norm layer. Use BN by default.
norm_eval (bool) – Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False.
with_cp (bool) – Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False.
frozen_stages (int) – Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Default: -1.
zero_init_residual (bool) – Whether to use zero init for last norm layer in resblocks to let them behave as identity. Default: False.
multiscale_output (bool) – Whether to output multi-level features produced by multiple branches. If False, only the first level feature will be output. Default: True.
pretrained (str, optional) – Model pretrained path. Default: None.
init_cfg (dict or list[dict], optional) – Initialization config dict. Default: None.
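In an actual fine-tuning run, frozen_stages would usually be set in the backbone section of the training config rather than by instantiating the class by hand. A minimal sketch (the fragment is illustrative, not a complete mmseg config; only type and frozen_stages matter here):

```python
# Illustrative mmseg-style config fragment: freeze the first stages of the
# HRNet backbone while fine-tuning the rest of the model.
model = dict(
    backbone=dict(
        type='HRNet',
        frozen_stages=2,   # -1 (the default) freezes nothing
        norm_eval=True))   # optionally also freeze BN running statistics
```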
The code is as follows:
from mmseg.models import HRNet
import torch

# HRNet stage configuration: each of the 4 stages must define
# num_modules, num_branches, block, num_blocks, and num_channels.
extra = dict(
    stage1=dict(
        num_modules=1,
        num_branches=1,
        block='BOTTLENECK',
        num_blocks=(4, ),
        num_channels=(64, )),
    stage2=dict(
        num_modules=1,
        num_branches=2,
        block='BASIC',
        num_blocks=(4, 4),
        num_channels=(32, 64)),
    stage3=dict(
        num_modules=4,
        num_branches=3,
        block='BASIC',
        num_blocks=(4, 4, 4),
        num_channels=(32, 64, 128)),
    stage4=dict(
        num_modules=3,
        num_branches=4,
        block='BASIC',
        num_blocks=(4, 4, 4, 4),
        num_channels=(32, 64, 128, 256)))

# Build the backbone and run a dummy single-channel input through it.
model = HRNet(extra, in_channels=1)
model.eval()
inputs = torch.rand(1, 1, 32, 32)
level_outputs = model.forward(inputs)
for level_out in level_outputs:
    print(tuple(level_out.shape))
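To make the behavior of frozen_stages concrete, here is a minimal, framework-free sketch of the freezing logic (stop gradients and switch to eval mode). The Param, Stage, and TinyBackbone classes are hypothetical stand-ins, not mmseg's real implementation; they mirror the convention described above, where frozen_stages=i freezes stages 0 through i.

```python
# Hypothetical stand-in classes sketching what frozen_stages does.

class Param:
    """Stand-in for a learnable tensor with a requires_grad flag."""
    def __init__(self):
        self.requires_grad = True

class Stage:
    """Stand-in for one backbone stage (a module holding parameters)."""
    def __init__(self, n_params):
        self.params = [Param() for _ in range(n_params)]
        self.training = True

    def eval(self):
        # eval mode freezes e.g. BatchNorm running statistics.
        self.training = False

class TinyBackbone:
    def __init__(self, num_stages=4, frozen_stages=-1):
        self.stages = [Stage(2) for _ in range(num_stages)]
        self.frozen_stages = frozen_stages
        self._freeze_stages()

    def _freeze_stages(self):
        # Freeze stages 0..frozen_stages inclusive: stop gradients and
        # put each frozen stage into eval mode. -1 freezes nothing.
        for stage in self.stages[:self.frozen_stages + 1]:
            stage.eval()
            for p in stage.params:
                p.requires_grad = False

net = TinyBackbone(frozen_stages=1)
print([s.training for s in net.stages])  # [False, False, True, True]
```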
Summary
At present, the network appears to support freezing only the backbone through this parameter.
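For parts of the model that frozen_stages does not cover (e.g. a decode head), a common workaround in PyTorch is to set requires_grad=False on those parameters manually and pass only the trainable ones to the optimizer. A sketch using a stand-in nn.Sequential instead of a real mmseg model:

```python
import torch
import torch.nn as nn

# Stand-in model: with a real mmseg model you would iterate, for example,
# over model.backbone.parameters() instead.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze the first linear layer: its parameters stop receiving gradients.
for p in model[0].parameters():
    p.requires_grad = False

# Pass only the still-trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)

print(sum(p.requires_grad for p in model.parameters()))  # prints 2
```

Note that optimizers with running state (e.g. Adam with momentum terms) can still nudge frozen parameters if they were passed to the optimizer, so filtering as above is safer than setting requires_grad alone.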