【pytorch】冻结网络踩坑

普通conv和fc层的冻结方式:

# 冻结参数
for i, p in enumerate(self.model.parameters()):
    if i <= 66:
        p.requires_grad = False


# 验证一下是否成功冻结参数
for k, v in self.model.named_parameters():
    print("k:{} v:{} ".format(k, v.requires_grad))

注意:model.parameters()都在梯度回传的更新过程中,所以可以用param.requires_grad = False的方式冻结,但是对于一些BN层的参数,比如BN层的runing_mean和runing_var,这两个值是前向计算统计得来的,并没有在梯度回传的更新过程中。所以,param.requires_grad=False对它们不起任何作用!

踩坑:

我的目的:在共用一个主干网络的多任务学习中,完全冻结其中一个表现较好的任务1分支,只训练其他两个任务:任务2分支和任务3分支。

结果:我以为用 “param.requires_grad=False” 的方式可以冻结任务1分支的所有参数,然后我发现我错了,冻结完,在验证过程中,我发现任务1的表现居然变差了。

验证:打印参数值,发现任务1的卷积层和全连接层参数不变(被成功冻结),只有BN层的runing_mean和runing_var发生了改变(未被冻结),应该就是他们的问题。

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值