深度学习加载预训练权重时冻结网络的部分参数

该文介绍如何在PyTorch中加载预训练模型权重,只保留与当前模型匹配的部分参数,并冻结除最后全连接层外的所有权重,用于后续的微调或特定层训练。使用`torch.load`加载权重,然后根据模型结构选择性地设置参数的`requires_grad`属性为False来冻结权重,只让最后的全连接层的参数参与训练。
摘要由CSDN通过智能技术生成

        这里以除了最后的全连接层,冻结网络其他参数为例:

weigth_path = './net.pth'
weights_dict = torch.load(weight_path, map_location=device)

# 只保留和模型参数个数一个的预训练参数块
load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
# 加载权重
model.load_state_dict(load_weights_dict, strict=False)
    
for name, para in model.named_parameters():
    # 除最后的全连接层外,其他权重全部冻结,注意这里的fc是在定义模型时命名的一个块
    if "fc" not in name:
       para.requires_grad_(False)

# 这个变量保存所有训练的参数,是供之后优化器使用的
pg = [p for p in model.parameters() if p.requires_grad]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值