很多时候当训练一个新的图像分类任务时,一般不会完全从一个随机的模型开始训练,而是利用预训练的模型来加速训练的过程。经常使用在ImageNet上的预训练模型。
- 这是一种transfer learning的方法。我们常用以下两种方法做迁移学习。
- fine tuning: 从一个预训练模型开始,我们改变一些模型的架构,然后继续训练整个模型的参数。
- feature extraction: 我们不再改变与训练模型的参数,而是只更新我们改变过的部分模型参数。我们之所以叫它feature extraction是因为我们把预训练的CNN模型当做一个特征提取模型,利用提取出来的特征做来完成我们的训练任务。
以下是构建和训练迁移学习模型的基本步骤:
- 初始化预训练模型
- 把最后一层的输出层改变成我们想要分的类别总数
- 定义一个optimizer来更新参数
- 模型训练
一、项目介绍:
1、目标:本文尝试采用CNN实现图像蜜蜂和蚂蚁图像分类任务。
2、数据说明:使用hymenoptera_data数据集。包括两类图片, bees 和 ants, 这些数据都被处理成了可以使用ImageFolder来读取的格式。
输入数据维度:;输出数据维度
3、torchvision的datasets.ImageFolder参数说明:
(1)data_dir:数据的存储目录。设置成数据的根目录
(2)model_name:训练时使用的模型。设置成自己的训练模型,也可以使用封装好的模型。如,resnet, alexnet, vgg, squeezenet, densenet, inception等
(3)num_classes:表示数据集分类的类别数
(4)batch_size
(5)num_epochs
(6)feature_extract:表示训练时使用fine tuning还是feature extraction方法。如果feature_extract = False,整个模型都会被同时更新。如果feature_extract = True,只有模型的最后一层被更新。
4、网络框架:
二、设置参数
import numpy as np
import torchvision
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import time
import os
import copy
print("Torchvision Version: ",torchvision.__version__)
data_dir = './hymenoptera_data'
#模型可以选择[resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = 'resnet'
#数据的label分类个数
num_classes = 2
#训练的batch_size
batch_size=32
#训练的轮数
num_epochs = 15
#整个模型参数都参数训练,进行更新
feature_extract = True
二、代码实现
1、初始化模型
#定义需要更新的参数
def set_parameter_requires_grad(model,feature_extracting):
if feature_extracting:
for param in model.parameters():
param.requires_grad = False
#使用resnet框架进行模型初始化
def initialize_model(model_name,num_class,feature_extract,use_pretrained=True):
if model_name == 'renet':
model_ft = models.resnet18(pretrained=use_pretrained) #使用resnet18最为初始化模型
set_parameter_requires_grad(model_ft,feature_extract) #模型中所有参数都更新
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs,num_classes) #根据分类个数重新定义最后一层全连接层
input_size = 224
return model_ft,input_size
model_ft,input_size = initialize_model(model_name,num_classes,feature_extract,use_pretrained=True)
print(model_ft)
2、导入数据集
根据模型输入的size,将数据预处理称为对应的格式
#定义图片的处理方式
data_transforms = {
'train':transforms.Comp