pytorch载入预训练模型后,只想训练个别层怎么办?

标签: pytorch finetune
5人阅读 评论(0) 收藏 举报
分类:

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
    if name 满足某些条件:
        value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
    print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
                              {'params':model.decoder.parameters()}
                              ],
                              lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。
在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有paramslr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

参考:

  1. pytorch官方文档
  2. https://blog.csdn.net/u012759136/article/details/65634477
查看评论

用Perl语言进行Socket编程

用Perl语言进行Socket编程 发布时间:2001年3月28日 00:00    网络编程是一门神秘且复杂的艺术,当然也十分有趣。Perl语言提供了丰富的TCP/IP网络函数,所有这些函数都直接来...
  • ghj1976
  • ghj1976
  • 2001-08-19 15:06:00
  • 1245

pytorch 如何加载部分预训练模型

pretrained_dict = ...model_dict = model.state_dict() # 1. filter out unnecessary keys pretraine...
  • AMDS123
  • AMDS123
  • 2017-03-19 14:55:38
  • 9707

pytorch 使用预训练层

pytorch 使用预训练层将其他地方训练好的网络,用到新的网络里面pytorch 使用预训练层 加载预训练网络 加载新网络 更新新网络参数加载预训练网络1.原先已经训练好一个网络 AutoEncod...
  • zzw000000
  • zzw000000
  • 2017-11-23 09:45:13
  • 262

PyTorch中使用预训练的模型初始化网络的一部分参数

在预训练网络的基础上,修改部分层得到自己的网络,通常我们需要解决的问题包括: 1. 从预训练的模型加载参数 2. 对新网络两部分设置不同的学习率,主要训练自己添加的层 一. 加载参数的方法: ...
  • u012494820
  • u012494820
  • 2018-01-15 20:50:55
  • 868

pytorch学习笔记(十一):fine-tune 预训练的模型

pytorch : fine-tune torchvision 中预训练的模型torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-...
  • u012436149
  • u012436149
  • 2017-09-20 10:42:59
  • 2708

PyTorch学习系列(十五)——如何加载预训练模型?

torch.nn.Module对象有函数static_dict()用于返回包含模块所有状态的字典,包括参数和缓存。键是参数名称或者缓存名称。函数Module::load_state_dict(stat...
  • VictoriaW
  • VictoriaW
  • 2017-05-31 16:24:23
  • 9922

PyTorch(7)——模型的训练和测试、保存和加载

目录连接 (1) 数据处理 (2) 搭建和自定义网络 (3) 使用训练好的模型测试自己图片 (4) 视频数据的处理 (5) PyTorch源码修改之增加ConvLSTM层 (6) 梯度反向...
  • u011276025
  • u011276025
  • 2017-11-11 19:53:54
  • 1175

pytorch:在网络中添加可训练参数,修改预训练权重文件

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。 假如我们现在有一个简单的两层感知机网络: # -*- coding: utf-8 -*- import to...
  • qq_19672579
  • qq_19672579
  • 2018-01-29 16:26:20
  • 438

使用pytorch预训练模型分类与特征提取

pytorch应该是深度学习框架里面比较好使用的了,相比于tensorflow,mxnet。可能在用户上稍微少一点,有的时候出问题不好找文章。下面就使用pytorch预训练模型做分类和特征提取,pyt...
  • u010165147
  • u010165147
  • 2017-06-01 12:12:01
  • 5860

PyTorch预训练

前言最近使用PyTorch感觉妙不可言,有种当初使用Keras的快感,而且速度还不慢。各种设计直接简洁,方便研究,比tensorflow的臃肿好多了。今天让我们来谈谈PyTorch的预训练,主要是自己...
  • u012759136
  • u012759136
  • 2017-03-24 17:12:16
  • 2668
    个人资料
    等级:
    访问量: 3462
    积分: 87
    排名: 151万+
    文章存档
    最新评论