【深度之眼】【Pytorch打卡第16天】:模型微调Finetune(迁移学习)

任务

任务简介

了解transfer learning 与 model finetune

详细介绍

学习模型微调(Finetune)的方法,以及认识Transfer Learning(迁移学习)与Model Finetune之间的关系。


知识点

Transfer Learning & Model Finetune

迁移学习:机器学习分支,研究源域(source domain)的知识如何应用到目标域(target domain)。

模型微调:所谓的模型微调,其实就是模型的迁移学习,在深度学习中,通过不断的迭代,更新卷基层中的权值,这里的权值可以称之为 knowledge , 然后我们可以将这些 knowledge 进行迁移,主要目的是将这些 knowledge 运用到新的模型中,这样既可以减小由于数据量不足导致的过拟合现象,同时又能加快模型的训练速度

具体说来,对于卷积神经网络,我们可以把前面的卷基层,池化层看作是 feature extactor(特征提取) ,是一个非常有共性的部分。得到一系列的feature map。
后面的全连接层,可以称之为 classifier (分类器), 与具体的任务有关。这一部分就需要针对不同的训练任务进行调整,尤其是最后一层需要根据任务进行相应的调整。

PyTorch中的Finetune

模型微调步骤

  1. 获取预训练模型参数—原任务中获取得到的知识
  2. 加载模型(load_state_dict)
  3. 修改输出层

模型微调训练方法

  1. 固定预训练的参数(requires_grad =False;lr=0)
  2. Features Extractor较小学习率(params_group)

实战-Resnet-18 用于二分类

Resnet-18 模型介绍

数据下载
模型下载

蚂蚁蜜蜂二分类数据
训练集:各120~张
验证集:各70~张
在这里插入图片描述

Resnet-18模型结构如下图所示:

前面四层是特征提取,接下来四层(layer1~layer4)是残差网络,然后接avgpool池化层,最后接FC分类(原模型是1000分类,ImageNet上训练的)。

迁移结果分析

(1) 直接训练
如果不采用Resnet-18模型进行Finetune,直接对二分类数据进行训练,得到的Loss曲线

损失值一直在0.6附近,并且得到的Accuracy只有70%

(2) 迁移训练,但不冻结卷积层,固定学习率

# ============================ step 2/5 模型 ============================
# 1/3 构建模型
resnet18_ft = models.resnet18()

# 2/3 加载参数
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)

可以看出,损失值最后收敛到在0.2附近,并且在第二个Epoch的Accuracy就达到了90%。

(3)迁移训练,冻结卷基层,固定学习率

从代码看,所谓冻结卷积层,是直接把参数的梯度设置为False

for param in resnet18_ft.parameters():
    param.requires_grad = False
# ============================ step 2/5 模型 ============================

# 1/3 构建模型
resnet18_ft = models.resnet18()

# 2/3 加载参数
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# 法1 : 冻结卷积层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值