pytorch载入预训练模型后,训练指定层

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)
 
 

    strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

    2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

    # 载入预训练模型参数后...
    for name, value in model.named_parameters():
        if name 满足某些条件:
            value.requires_grad = False
    
    # setup optimizer
    params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(params, lr=1e-4)
     
     

      将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

      3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

      # 载入预训练模型参数后...
      for name, value in model.named_parameters():
          print(name)
      # 或
      print(model.state_dict().keys())
       
       

      假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

      'encoder.visual_emb.0.weight',
      'encoder.visual_emb.0.bias',
      'viewer.bd.Wsi',
      'viewer.bd.bias',
      'decoder.core.layer_0.weight_ih',
      'decoder.core.layer_0.weight_hh',
       
       

      假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

      ignored_params = list(map(id, model.decoder.parameters()))
      base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
      optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
                                    {'params':model.decoder.parameters()}
                                    ],
                                    lr=1e-4, momentum=0.9)
       
       

      代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。
      在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有paramslr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

      参考:

      1. pytorch官方文档
      2. https://blog.csdn.net/u012759136/article/details/65634477
      • 4
        点赞
      • 8
        收藏
        觉得还不错? 一键收藏
      • 0
        评论

      “相关推荐”对你有帮助么?

      • 非常没帮助
      • 没帮助
      • 一般
      • 有帮助
      • 非常有帮助
      提交
      评论
      添加红包

      请填写红包祝福语或标题

      红包个数最小为10个

      红包金额最低5元

      当前余额3.43前往充值 >
      需支付:10.00
      成就一亿技术人!
      领取后你会自动成为博主和红包主的粉丝 规则
      hope_wisdom
      发出的红包
      实付
      使用余额支付
      点击重新获取
      扫码支付
      钱包余额 0

      抵扣说明:

      1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
      2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

      余额充值