Pytorch载入部分参数并冻结

本文介绍了在PyTorch中如何加载部分模型参数,并在训练时冻结特定层。当网络层名称和结构不变时,加载参数相对简单。但如果引入的网络层名称发生修改,需要构建新的state_dict来匹配。此外,文章还提供了冻结参数的方法,通过设置优化器以过滤不需要更新的参数。
摘要由CSDN通过智能技术生成

参考资料

  1. pytorch 模型部分参数的加载
  2. Pytorch中,只导入部分模型参数的做法
  3. Correct way to freeze layers
  4. Pytorch自由载入部分模型参数并冻结
  5. pytorch冻结部分参数训练另一部分
  6. PyTorch更新部分网络,其他不更新
  7. Pytorch固定部分参数(只训练部分层)

加载部分参数

如果加载现有模型的所有参数,我们常使用的是代码如下:

torch.load(model.state_dict())

在训练过程中,我们常常会使用预训练模型,有时我们是在自己的模型中加入别人的某些模块,或者对别人的模型进行局部修改,这个时候再使用torch.load(model.state_dict()),就会出现类似这些的错误:RuntimeError: Error(s) in loading state_dict for Net:Missing key(s) in state_dict:xxx。出现这个错误就是某些参数缺失或者不匹配。

保持原来网络层的名称和结构不变

现有模型中引入的那部分网络结构的网络层的名称和结构保持不变,这时候加载参数的代码很简单。

# 加载引入的网络模型
model_path = "xxx"
checkpoint = torch.<
  • 4
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值