加载预训练模型,模型微调,在自己的数据集上快速出效果

  • 针对于某个任务,自己的训练数据不多,先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,调整一下参数,再训练一遍,这就是微调(fine-tune)。 PyTorch里面提供的经典的网络模型都是官方通过Imagenet的数据集与训练好的数据,如果我们的数据训练数据不够,这些数据是可以作为基础模型来使用的。(Fine tuning 模型微调)

  • Fine tuning 模型微调的好处

    • 对于数据集本身很小(几千张图片)的情况,从头开始训练具有几千万参数的大型神经网络是不现实的,因为越大的模型对数据量的要求越大,过拟合无法避免。这时候如果还想用上大型神经网络的超强特征提取能力,只能靠微调已经训练好的模型。

    • 可以降低训练成本:如果使用导出特征向量的方法进行迁移学习,后期的训练成本非常低,用 CPU 都完全无压力,没有深度学习机器也可以做。

    • 前人花很大精力训练出来的模型在大概率上会比你自己从零开始搭的模型要强悍,没有必要重复造轮子

  • 迁移学习初衷是节省人工标注样本的时间,让模型可以通过一个已有的标记数据的领域向未标记数据领域进行迁移从而训练出适用于该领域的模型,直接对目标域从头开始学习成本太高,我们故而转向运用已有的相关知识来辅助尽快地学习新知识。把统一的概念抽象出来,只学习不同的内容。迁移学习按照学习方式可以分为基于样本的迁移,基于特征的迁移,基于模型的迁移,以及基于关系的迁移。

  • 微调应该是迁移学习中的一部分。微调只能说是一个trick,一种技术;迁移学习是一个更宏大的概念

  • Pytorch模型保存、加载与预训练

  • 保存和加载整个模型和参数:这种方式会保存整个模型的结构以及参数,会占用较大的磁盘空间, 通常不采用这种方式

  • torch.save(model, 'model.pkl')  #保存
    model = torch.load('model.pkl') # 加载
    
  • 保存和加载模型的参数, 优点是速度快,占用的磁盘空间少, 是最常用的模型保存方法。load_state_dict有一个strict参数,该参数默认是True, 表示预训练模型的网络结构与自定义的网络结构严格相同(包括名字和维度)。 如果自定义网络和预训练网络不严格相同时, 需要将不属于自定义网络的key去掉

  • torch.save(model.state_dict(), 'model_state_dict.pkl')
    model = model.load_state_dict(torch.load(model_state_dict.pkl))
    
  • 在实际场景中, 我们往往需要保存更多的信息,如优化器的参数, 那么可以通过字典的方式进行存储

  • # 保存
    torch.save({'epoch': epochId,
                'state_dict': model.state_dict,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict()}, 
                 checkpoint_path + "/m-" + timestamp + str("%.4f" % best_acc) + ".pth.tar")
    # 加载
    def load_model(model, checkpoint, optimizer):
          model_CKPT = torch.load(checkpoint)
          model.load_state_dict(model_CKPT['state_dict'])
          optimizer.load_state_dict(model_CKPT['optimizer'])
        return model, optimizer
    
  • 加载部分预训练模型: 如果我们修改了网络, 那么就需要将这部分参数过滤掉:(值得注意的是,当两个网络的结构相同, 但是结构的命名不同时, 直接加载会报错。因此需要修改结构的key值)

  • def load_model(model, chinkpoint, optimizer):
          model_CKPT = torch.load(checkpoint)
          model_dict = model.state_dict()
          pretrained_dict = model_CKPT['state_dict']
          # 将不在model中的参数过滤掉
          new_dict = {k, v for k, v in pretrained_dict.items() if k in model_dict.keys()}
          model_dict.update(new_dict)
          model.load_state_dict(model_dict)
          # 加载优化器参数
          optimizer.load_state_dict(model_CKPT['optimizer'])
        return model, optimizer
    
  • 冻结网络的部分参数, 训练另一部分参数(注意,必须同时在优化器中将这些参数过滤掉, 否则会报错。因为optimizer里面的参数要求required_grad为Ture)

    • 当输入给模型的数据集形式相似或者相同时,常见的是利用现有的经典模型(如Residual Network、 GoogleNet等)作为backbone来提取特征,那么这些经典模型已经训练好的模型参数可以直接拿过来使用。通常情况下, 我们希望将这些经典网络模型的参数固定下来, 不进行训练,只训练后面我们添加的和具体任务相关的网络参数。

      • 新数据集和原始数据集合类似,那么直接可以微调一个最后的FC层或者重新指定一个新的分类器

      • 新数据集比较小和原始数据集合差异性比较大,那么可以使用从模型的中部开始训练,只对最后几层进行fine-tuning

      • 新数据集比较小和原始数据集合差异性比较大,如果上面方法还是不行的化那么最好是重新训练,只将预训练的模型作为一个新模型初始化的数据

      • 新数据集的大小一定要与原始数据集相同,比如CNN中输入的图片大小一定要相同,才不会报错

      • 对于不同的层可以设置不同的学习率,一般情况下建议,对于使用的原始数据做初始化的层设置的学习率要小于(一般可设置小于10倍)初始化的学习率,这样保证对于已经初始化的数据不会扭曲的过快,而使用初始化学习率的新层可以快速的收敛。

  • # 以ResNet网络为例
    # 当我们加载ResNet预训练模型之后,在ResNet的基础上连接了新的网络模块, ResNet那部分网络参数先冻结不更新
    # 只更新新引入网络结构的参数
    class Net(torch.nn.Module):
          def __init__(self, model, pretrained):
              super(Net, self).__init__()
              self.resnet = model(pretained)
              for p in self.parameters():
                  p.requires_grad = False
              self.conv1 = torch.nn.Conv2d(2048, 1024, 1)
              self.conv2 = torch.nn.Conv2d(1024, 1024, 1)
    
  • 参数修改: resnet网络的最后一层对应1000个类别, 如果我们自己的数据只有10个类别, 那么可以进行如下修改

  • import torch
    import torchvision.models as models
    model = models.resnet50(pretrained=True)
    fc_inDim = model.fc.in_features
    # 修改为10个类别
    model.fc = torch.nn.Linear(fc_inDim, 10)
    
  • Pytorch有很多方便易用的包,今天要谈的是torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。

  • 5
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

羞儿

写作是兴趣,打赏看心情

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值