改动模型后,加载部分预训练权重文件

本文介绍了如何通过加载部分预训练权重来改善自定义模型的泛化性能。作者在姿态估计任务中,将原本基于COCO数据集(17个关节)训练的Hrnet模型修改为适应15个关节的情况。由于数据量小,模型快速收敛但泛化差。通过只加载预训练模型中未修改层的权重,尤其是去掉最后一层全连接层的后两个元素,模型的泛化性显著增强。这一方法的核心是将权重文件视为字典,并根据新模型结构匹配并加载合适的键值对。
摘要由CSDN通过智能技术生成

加载部分预训练权重文件

最近在做姿态估计相关内容,需要将Hrnet模型修改,Hrnet是基于coco数据集训练的,coco数据集是17个关节点,而我需要的是15个关节点,在将数据集标好训练之后,发现由于数据量比较小,模型能够很快收敛,但是泛化性能极差,于是,就想着把之前的coco预训练权重文件拿出来一部分,对我自己的模型进行训练,果不其然,在使用部分预训练权重文件进行训练后,模型的泛化性有了很大的改善,现在分享给大家。

1首先我们需要明确权重文件的类型是什么:我们在使用pytorch进行模型训练的时候,最后的权重文件实际上是一个字典, 只不过是一个有序字典OrderedDict类 ,关于这个类的各种操作,请参考这篇博客, 里面已经说的很详细了OrderedDict

2在明确权重文件其实就是一个字典类的时候,那么我们就能了解,权重文件其实就是key+value,所谓key就是每一层的关键字,而value就是每一层的矩阵数据,下面以一份权重文件为例:

预训练权重文件

我们在加载预训练权重文件之后,发现就是一个字典,并且是一个有序字典,那么同样,我们可以打印出字典的关键字:

字典关键字

在了解上述操作过程以后,那么如果我们想要加载部分预训练权重文件就很简单啦。

首先,我们需要将我们实例化删改后的模型:

model = YOUR_changed_model(**)

其次,加载你删改后模型的state_dict()

model_state_dict= model.state_dict()

同样,model_state_dict()也是一个字典文件,因为我们已经改变了模型,但是我们改变的只是模型的一部分,换句话说,改变的只是权重文件字典中的某些keys或者values,而我们加载的部分权重文件其实就是在原来权重文件中没有修改的。以我的模型为例,我只是改变了模型的最后的全连接层,本来最后一维是17维,我需要的是15维,那么也就是说出去最后一层的预训练权重文件,我都是可以使用的,并且最后一维我可以使用预训练权重文件的前15维,因此,修改如下:

for i, (k, v) in enumerate(model_state.items()):
    if i < 1752:
        model_state[k] = pretrained_weights[k] 
    else:
        model_state[k] = pretrained_weights[k][:-2]
torch.save(model_state, 'best.pt')  # 保存权重文件

最后由于我的模型修改比较简单,所以,调取预训练权重文件也比较容易,但是核心思想是一致的,就是把权重文件看作是一个字典,在我们新的模型中添加原来权重文件中存在的key以及value。

预训练模型BERT的功能是通过对大规模文本数据进行预训练,学习到丰富的语言表示,从而能够在各种下游任务中进行微调,提供更好的语义理解和表达能力。\[1\] BERT的预训练过程包括两个阶段,首先是通过双向语言模型进行语言模型预训练,然后使用Fine-Tuning模式解决下游任务。\[3\] BERT模型的训练过程中主要是微调分类器,而不需要对BERT模型本身进行大幅度的改动。\[1\] BERT模型的缺点包括随机遮挡策略较粗犷,对硬件资源消耗较大,以及收敛速度较慢等。\[2\] 总的来说,BERT模型的功能是通过预训练和微调,提供更好的语言理解和表达能力,适用于各种自然语言处理任务。 #### 引用[.reference_title] - *1* [NLP专栏|图解 BERT 预训练模型!](https://blog.csdn.net/Datawhale/article/details/109476057)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [BERT预训练模型的演进过程!(附代码)](https://blog.csdn.net/weixin_41510260/article/details/101641415)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Begin,again

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值