冻结pytorch的网络有两种方式
1. 冻结方式一:lr=0
caffe反正是这么写的,具体方式如下
model_params = [{'params': base_params, 'lr': 0}, # 注释【1】
{'params': new_params, 'lr': cfg.SOLVER.BASE_LR * cfg.SOLVER.LR_MULTIPLE}]
# base_params和new_params都是提前设定好的
#然后初始化优化器
optimizer = torch.optim.SGD(
model_params,
lr=solver.BASE_LR * lr_multi,
momentum=solver.MOMENTUM,
weight_decay=solver.WEIGHT_DECAY
)
【1】干脆不设置这一组也是可以的
2. 冻结方式二:require_grad=False
pytorch和mxnet都提供了相似的接口
for p in model.named_parameters():
if p[0] in match_layers:
p[1].requires_grad = False
3. 补充:只冻结主干网络的方法
#假设主干网络是model.module.mobilenet,backbone=mobilenet
try:
sub_model = eval('model.module.' + backbone)
#也等同于 model.module.mobilenet.eval()
sub_model.eval()
except Exception as e:
pass