11.2 模型finetune

一、Transform Learning 与 Model Finetune

二、pytorch中的Finetune

 

一、Transfer Learning 与 Model Finetune

1. 什么是Transfer Learning?

迁移学习是机器学习的一个分支,主要研究源域的知识如何应用到目标域当中。迁移学习是一个很大的概念。

怎么理解源域的知识应用到目标域当中呢?上图是来自一篇迁移学习的综述。左边是传统机器学习的过程,对于不同的任务分别学习得到不同的模型。而右边是迁移学习的示意图,不同的任务会划分为源任务和目标任务,对原任务进行学习,学习到的称之为知识,而我们回利用知识和目标任务进行学习,得到模型。这个模型不仅用到了目标任务,还用到了原任务的知识。

迁移学习就是将源任务的知识应用到目标任务中。

2. 迁移学习与finetune之间的关系

我们训练一个模型,就是不断地更新他的权值。而整个模型最重要的东西也就是他的权值。这个权值呢,也就可以称之为他的知识。而这些知识是可以进行迁移的。我们把这些知识迁移到新任务中,这就是模型微调。

 

为什么我们使用model finetune这个trick呢?这是因为在新任务中,数据量较小。

我们来看,神经网络该如何迁移。我们对神经网络,通常会划分为两部分,前面一些列的卷积池化,我们认为是特征提取。后面一些全连接层,我们称之为分类器。

我们对特征提取的部分,认为是比较有共性的地方。而分类器的参数呢,我们认为它与具体的任务有关,通常需要去改变。在这里,有个非常重要的地方,通常都要去改变,这就是最后一个输出层。比如原来是千分类任务,这里是二分类任务,这就需要改变。

 

二、pytorch中的Finetune

下面我们来看模型finetune需要哪些步骤。

构建好模型之后,在训练时也会常用一些trick。

1. 固定预训练的参数(两种方法:(1) requires_grad = False    (2)学习率设为0)

2. 使用较小的学习率。这时候就要用到params_group(参数组)的概念,让不同的部分学习率不同。

 

三、举例 

下面使用Resnet-18进行finetune,来完成时频图二分类任务。

 

(1)准备工作

模型下载:https://download.pytorch.org/models/resnet18-5c106cde.pth

数据准备:

|----data

        |----pubu                                       #下载的数据集。

                   |----train

                            |----saopin

                            |----wurenji

                   |----test

        |----resnet18-5c106cde.pth        #预训练的模型

|----src

         |-----finetune_resnet18.py

|----tools                                           #通用的一些函数

         |----my_dataset.py                  #Dataset

 

(2)resnet18模型结构

首先是卷积,BN,ReLU,Pool这么一组操作,我们认为是初步的特征提取。

然后是4个残差的blok,会进行一系列的特征提取。

再后面是一个池化。

最后是fc。这个fc是1000分类的任务。

 

(3)代码

例1:不使用trick:所有的参数使用同一个学习率。

finetune_resnet18_1.py

# -*- coding: utf-8 -*-
"""
# @file name  : finetune_resnet18_1.py
# @brief      : 模型finetune方法,方法一:使用同一个学习率。
"""
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt

import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from tools.my_dataset import PubuDataset
from tools.common_tools import set_seed
import torchvision.models as models
import torchvision
BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use device :{}".format(device))

set_seed(1)  # 设置随机种子
label_name = {"ants": 0, "bees": 1}

# 参数设置
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 7


# ============================ step 1/5 数据 ============================
data_dir = os.path.abspath(os.path.join(BASEDIR, "..", "data", "pubu"))
if not os.path.exists(data_dir):
    raise Exception("\n{} 不存在,请下载 07-02-数据-模型finetune.zip  放到\n{} 下,并解压即可".format(
        data_dir, os.path.dirname(data_dir)))

train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_d
  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
提供的源码资源涵盖了安卓应用、小程序、Python应用和Java应用等多个领域,每个领域都包含了丰富的实例和项目。这些源码都是基于各自平台的最新技术和标准编写,确保了在对应环境下能够无缝运行。同时,源码中配备了详细的注释和文档,帮助用户快速理解代码结构和实现逻辑。 适用人群: 这些源码资源特别适合大学生群体。无论你是计算机相关专业的学生,还是对其他领域编程感兴趣的学生,这些资源都能为你提供宝贵的学习和实践机会。通过学习和运行这些源码,你可以掌握各平台开发的基础知识,提升编程能力和项目实战经验。 使用场景及目标: 在学习阶段,你可以利用这些源码资源进行课程实践、课外项目或毕业设计。通过分析和运行源码,你将深入了解各平台开发的技术细节和最佳实践,逐步培养起自己的项目开发和问题解决能力。此外,在求职或创业过程中,具备跨平台开发能力的大学生将更具竞争力。 其他说明: 为了确保源码资源的可运行性和易用性,特别注意了以下几点:首先,每份源码都提供了详细的运行环境和依赖说明,确保用户能够轻松搭建起开发环境;其次,源码中的注释和文档都非常完善,方便用户快速上手和理解代码;最后,我会定期更新这些源码资源,以适应各平台技术的最新发展和市场需求。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值