PyTorch冻结已训练网络参数

在训练多层神经网络中,我们发现由于网络参数过多,网络收敛的条件有点苛刻。因此,分层训练的方式在日常生活中常常被用到。所谓的分层训练,顾名思义,即多层网络中,我们先训练好第一层网络,固定其参数,去训练第二层网络,当第二层网络训练完毕,就固定前两层参数,去训练第三层网络,以此类推。下面展现代码的实现方式。

# 网络模型
class Test(torch.nn.Module):
    def __init__(self,): 
        super(TwISTA, self).__init__()
        …………………
    def forward(self, y, max_itr):
		…………………
 

network = Test()   
  • 首先,搭建一层神经网络将其参数调到最优,并把网络参数保存下来。
torch.save(network.state_dict(), '/home/data_ssd/Test.pth')  # 保存到指定目录:/home/data_ssd/,文件名称格式:Test.pth
  • 其次,搭建下一层网络,调通、运行代码。用以下代码查看网络模型名称。
model_dict = network.state_dict()
for k, v in model_dict.items():  # 查看自己网络参数各层名称、数值
	print(k)  # 输出网络参数名字
    # print(v)  # 输出网络参数数值

运行结果:

fcs.0.thr
fcs.0.beta
fcs.0.Tw_alpha
fcs.0._W.weight
fcs.0._S.weight
fcs.1.thr
fcs.1.beta
fcs.1.Tw_alpha
fcs.1._W.weight
fcs.1._S.weight

我们可以看到,以“fcs.0”开头的参数是第一层网络,以“fcs.1”开头的参数是第二层网络。

  • 加载第一层已训练好的参数,关闭梯度。
pretrained_dict = torch.load('/home/data_ssd/Test.pth')  # 到相应目录加载刚刚保存的文件(网络参数)
model_dict['fcs.0.thr'] = pretrained_dict['fcs.0.thr']
model_dict['fcs.0.beta'] = pretrained_dict['fcs.0.beta']
model_dict['fcs.0.Tw_alpha'] = pretrained_dict['fcs.0.Tw_alpha']
model_dict['fcs.0._W.weight'] = pretrained_dict['fcs.0._W.weight']
model_dict['fcs.0._S.weight'] = pretrained_dict['fcs.0._S.weight']
# 第一层网络fcs.0不再参与训练,关闭梯度
for name, param in network.named_parameters():
    # print(name)
    if "fcs.0" in name:
        param.requires_grad = False
        
# 查看是否关闭成功
for name, param in network.named_parameters():
    if param.requires_grad:
        print("requires_grad: True ", name)
    else:
        print("requires_grad: False ", name)

运行结果:

requires_grad: False  fcs.0.thr
requires_grad: False  fcs.0.beta
requires_grad: False  fcs.0.Tw_alpha
requires_grad: False  fcs.0._W.weight
requires_grad: False  fcs.0._S.weight
requires_grad: True  fcs.1.thr
requires_grad: True  fcs.1.beta
requires_grad: True  fcs.1.Tw_alpha
requires_grad: True  fcs.1._W.weight
requires_grad: True  fcs.1._S.weight

从结果中可以看出,以“fcs.0”开头的参数(即第一层网络参数)均已关闭梯度。

  • 最后,在优化器中屏蔽掉第一层网络参数不再参与训练,仅仅训练新添加的网络参数。
opt = torch.optim.Adam(filter(lambda p: p.requires_grad, network.parameters()), lr=adam_lr)  # 过滤掉没有梯度的参数

以此类推,分别优化多层网络

  • 16
    点赞
  • 46
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值