【pytorch】加载部分模型参数及冻结部分参数

前言

结论:1)部分加载模型参数的关键就是自定义选取训练好的模型的state_dict的键值对,然后更新到需要加载模型参数的新模型的state_dict中。 2)冻结部分参数的关键就是自定义设置需冻结的参数的requires_grad属性值为False,并在优化器中传入参数时,过滤掉requires_grad=False的这部分参数,使其不参与更新。 下文通过实例记录如何在pytorch中只加载部分参数,及冻结部分参数进行训练。
 
背景如下,有一个基于bert的内容分类器和基于bert的序列标注分类器,以下简称classifier与identifier。两个模型在特征抽取部分都采用相同的网络结构,就是bare_bert部分,但两种分类器的类型数量不一样,除此之外,bare_bert后续的网络结构也不同,最终需求为:“先训练classifier,然后将classifier的特征抽取部分bare_bert的参数共享给identifier,然后单独训练identifier。在训练identifier时,其共享的bare_bert的参数保持不变,只更新其层的网络参数。”
 

下面给了classifier与identifier的人造例子,有助于理解,classifier的定义如下:
在这里插入图片描述
classifier各层网络参数值为:
在这里插入图片描述

identifier定义如下:
在这里插入图片描述
identifier各层的参数值为:
在这里插入图片描述

可以发现,“fc1.weight, fc1.bias, fc2.weight, fc2.bias”是classifier与identifier都共有的,但具体值不相同,这里假设classifier是训练完之后的参数值,下面记录如何将共有的参数值共享给identifier。
 

共享部分参数

共享模型参数的核心工具就是模型的state_dict方法与load_state_dict方法。状态字典本质是python中的有序字典。

# *************** 自定义取出需要共享的参数 *******************
from collections import OrderedDict
temp = OrderedDict()

ide_state_dict = identifier.state_dict(destination=None)
for name, parameter in classifier.named_parameters():
	if name in ide_state_dict:
		temp[name] = parameter

# ************** 将共享的参数更新到需训练的模型中 ****************
ide_state_dict.update(temp)  # 更新参数值
identifier.load_state_dict(ide_state_dict)

此时再查看identifier的参数值,可以发现“fc1.weight, fc1.bias, fc2.weight, fc2.bias”部分的参数值已经是classifier的了,注意此时的“requires_grad=True”,下面记录如何冻结部分参数。
在这里插入图片描述

冻结部分参数

模型中的Parameter本质是Tensor的子类,因此其有Tensor的所有属性,其中“requires_grad”属性决定是否需要计算梯度,默认情况下,模型网络中的参数是需要记录梯度的,因此需要将“requires_grad”设置为False,并且在优化器中过滤掉这部分参数。如下所示:

# 自定义冻结部分参数
for name, parameter in identifier.named_parameters():
    if 'ide_only' not in name:
        parameter.requries_grad = False

# 过滤传入优化器的参数
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, identifier.parameters()))

 

参考资料

pytorch只加载预训练模型中的部分参数及冻结部分参数

Pytorch中,只导入部分模型参数的做法

How the pytorch freeze network in some layers, only the rest of the training?

pytorch冻结部分参数训练另一部分

  • 9
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值