%matplotlib inline
3.2 迁移学习(transfer learning)
关于迁移学习的一些基础知识可以参考王晋东的迁移学习手册:https://max.book118.com/html/2018/0902/5144341211001312.shtm
迁移学习的意义何在?
在实践中,很少有人从头开始训练整个卷积网络(随机初始化),因为拥有足够大小的数据集是相对罕见的。相反,通常在非常大的数据集(例如ImageNet,其包含具有1000个类别的120万个图像)上预先训练ConvNet,然后使用ConvNet初始化我们要训练的网络或者作为网络的固定参数。
迁移学习的两种使用场景:
1.微调网络。使用已经训练好的网络,如在imagenet 1000数据集上训练的网络,去初始化我们要训练的网络,而不是采用随机初始化。
2.作为固定特征提取器。将要训练网络的前面的层,除了最后的全连接层,全都用已经训练好的网络的这些层的参数固定住,只有最后的全连接层利用随机初始化参数的方式去训练。
下面是pytorch提供的pre-trained models:
class torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda,last_epoch=-1)
将每一个参数组的学习率设置为初始学习率lr的某个函数倍.当last_epoch=-1时,设置初始学习率为lr.
参数:
optimizer(Optimizer对象)–优化器
lr_lambda(是一个函数,或者列表(list)):当是一个函数时,需要给其一个整数参数,使其计算出一个乘数因子,用于调整学习率,通常该输入参数是epoch数目或者是一组上面的函数组成的列表
last_epoch(int类型):最后一次epoch的索引,默认为-1.
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
plt.ion() # interactive mode
下载数据包
使用torchvision和torch.utils.data包来加载数据
要解决的问题是训练一个模型来对蚂蚁和蜜蜂进行分类,每一类有120个训练图像,和75个验证图像,这里,我们使用迁移学习来训练这个模型。
这个数据集是imagenet下的一个子集。
transforms.RandomResizedCrop
:
将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为制定的大小(即先随机采集,然后对裁剪得到的图像缩放为同一大小).该操作的含义:即使只是该物体的一部分,我们也认为这是该类物体,比如 猫的图片别裁剪缩放后,仍然认为这是一个猫.
torchvision.transforms.RandomHorizontalFlip
:
以给定的概率随机水平翻转给定的PIL图像,参数:p(float) - 图像被翻转的概率。默认值为0.5
torchvision.transforms.Resize(size, interpolation=2)
:
如果size是类似(h,w),则输出大小将与此匹配;如果只有一个整数size,则图像的较小边缘将与此数字匹配,如果高度>宽度,则图像将重新缩放为(size * height / width, size)
torchvision.transforms.CenterCrop(size)
:
在中心裁剪给定的PIL图像,如果size是一个整数,则裁剪尺寸为(size,size)
PIL 读出来的image格式是图片的(width, height)。
ImageFolder 在pytorch Dataset下面的一个类,ImageFolder假设所有的文件按文件夹保存好,每个文件夹下面存贮同一类别的图片,文件夹的名字为分类的名字。
ImageFolder(root,transform=None,target_transform=None,loader=default_loader)
root : 在指定的root路径下面寻找图片 ;
transform: 对PIL Image进行转换操作,transform 输入是loader读取图片返回的对象 ;
target_transform :对label进行变换 ;
loader: 指定加载图片的函数,默认操作是读取PIL image对象
os.path.join(data_dir,x):把路径和文件名合在一起;
data_transforms = {
'train' : transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
]),
'val' : transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
]),
}
data_dir = '/home/dhb/jupyter notebook/data/hymenoptera_data/'
image_datasets = {
x:datasets.ImageFolder(os.path.join(data_dir,x),
data_transforms<