Pytorch实现CNN模型的迁移学习——蜜蜂和蚂蚁图片分类项目

很多时候当训练一个新的图像分类任务时,一般不会完全从一个随机的模型开始训练,而是利用预训练的模型来加速训练的过程。经常使用在ImageNet上的预训练模型。

  • 这是一种transfer learning的方法。我们常用以下两种方法做迁移学习。
    • fine tuning: 从一个预训练模型开始,我们改变一些模型的架构,然后继续训练整个模型的参数。
    • feature extraction: 我们不再改变与训练模型的参数,而是只更新我们改变过的部分模型参数。我们之所以叫它feature extraction是因为我们把预训练的CNN模型当做一个特征提取模型,利用提取出来的特征做来完成我们的训练任务。
      以下是构建和训练迁移学习模型的基本步骤:
  • 初始化预训练模型
  • 把最后一层的输出层改变成我们想要分的类别总数
  • 定义一个optimizer来更新参数
  • 模型训练

一、项目介绍:

1、目标:本文尝试采用CNN实现图像蜜蜂和蚂蚁图像分类任务。
2、数据说明:使用hymenoptera_data数据集。包括两类图片, bees 和 ants, 这些数据都被处理成了可以使用ImageFolder来读取的格式。
输入数据维度:;输出数据维度
3、torchvision的datasets.ImageFolder参数说明:
(1)data_dir:数据的存储目录。设置成数据的根目录
(2)model_name:训练时使用的模型。设置成自己的训练模型,也可以使用封装好的模型。如,resnet, alexnet, vgg, squeezenet, densenet, inception等
(3)num_classes:表示数据集分类的类别数
(4)batch_size
(5)num_epochs
(6)feature_extract:表示训练时使用fine tuning还是feature extraction方法。如果feature_extract = False,整个模型都会被同时更新。如果feature_extract = True,只有模型的最后一层被更新。
4、网络框架:

二、设置参数

import numpy as np
import torchvision
from torchvision import datasets, transforms, models

import matplotlib.pyplot as plt
import time
import os
import copy
print("Torchvision Version: ",torchvision.__version__)

data_dir = './hymenoptera_data'
#模型可以选择[resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = 'resnet'
#数据的label分类个数
num_classes = 2
#训练的batch_size
batch_size=32
#训练的轮数
num_epochs = 15
#整个模型参数都参数训练,进行更新
feature_extract = True

二、代码实现

1、初始化模型
#定义需要更新的参数
def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
#使用resnet框架进行模型初始化           
def initialize_model(model_name,num_class,feature_extract,use_pretrained=True):
    if model_name == 'renet':
        model_ft = models.resnet18(pretrained=use_pretrained)  #使用resnet18最为初始化模型
        set_parameter_requires_grad(model_ft,feature_extract)  #模型中所有参数都更新
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)  #根据分类个数重新定义最后一层全连接层
        input_size = 224
        
    return model_ft,input_size
model_ft,input_size = initialize_model(model_name,num_classes,feature_extract,use_pretrained=True)
print(model_ft) 
2、导入数据集

根据模型输入的size,将数据预处理称为对应的格式

#定义图片的处理方式
data_transforms = {
   
    'train':transforms.Comp
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值