Freezing Some Layers' Parameters When Training a Network in PyTorch

First, here is a very good PyTorch forum reply, from https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088/2

I faced this just a few days ago, so I’m sure this code should be up to date. Here’s my answer for Resnet, but this answer can be used for literally any model.

The basic idea is that every model has a method model.children() which returns its layers. Within each layer there are parameters (or weights), which can be obtained by calling .parameters() on any child (i.e. layer). Now, every parameter has an attribute called requires_grad, which is True by default. True means it will be backpropagated, and hence to freeze a layer you need to set requires_grad to False for all parameters of that layer. This can be done like this:

from torchvision import models

model_ft = models.resnet50(pretrained=True)
ct = 0
for child in model_ft.children():
    ct += 1
    if ct < 7:
        # freeze every parameter in this child module
        for param in child.parameters():
            param.requires_grad = False

This freezes layers 1-6 of the 10 top-level layers of Resnet50. Hope this helps!
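As a quick check, you can print the top-level children of resnet50 to see exactly which modules the counter walks over; a minimal sketch, assuming the model_ft from the snippet above:

# List resnet50's ten top-level children so the counter is concrete.
for i, (name, child) in enumerate(model_ft.named_children(), start=1):
    print(i, name)
# 1 conv1, 2 bn1, 3 relu, 4 maxpool, 5 layer1, 6 layer2,
# 7 layer3, 8 layer4, 9 avgpool, 10 fc
# so ct < 7 freezes everything up to and including layer2.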

 

However, this only sets requires_grad to False. It is better to also filter these parameters out when building the optimizer; although they would not be updated anyway, older PyTorch versions raise a ValueError if parameters that don't require gradients are passed to an optimizer:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
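Putting the two steps together, here is a minimal end-to-end sketch; it reuses the model_ft variable from the snippet above, and the momentum value is an arbitrary choice for illustration:

import torch.optim as optim

# Collect only the parameters that are still trainable after freezing.
trainable_params = [p for p in model_ft.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=1e-3, momentum=0.9)

# Sanity check: compare trainable vs. total parameter counts.
n_total = sum(p.numel() for p in model_ft.parameters())
n_train = sum(p.numel() for p in trainable_params)
print(f"trainable parameters: {n_train} / {n_total}")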

 

References:

https://zhuanlan.zhihu.com/p/34147880

https://blog.csdn.net/guotong1988/article/details/79739775

 
