- 基于经典网络架构训练图像分类模型¶
数据预处理部分:
数据增强:torchvision中transforms模块自带功能,比较实用
数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可
DataLoader模块直接读取batch数据
- 网络模块设置:
加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习
需要注意的是别人训练好的任务跟咱们的可不是完全一样,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务
训练时可以全部重头训练,也可以只训练最后咱们任务的层,因为前几层都是做特征提取的,本质任务目标是一致的
- 网络模型保存与测试
模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存
读取模型进行实际测试
1 模块导入
import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
#imageio:一个简单的接口来读取和写入各种图像数据
#sys:该模块提供对解释器使用或维护的一些变量的访问,以及与解释器强烈交互的函数
#json:使用 json 模块来对 JSON 数据进行编解码
#PIL是Python平台事实上的图像处理标准库,支持多种格式,并提供强大的图形与图像处理功能。
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2 数据读取与预处理操作
#路径设置
data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
2.1制作好数据源:
data_transforms中指定了所有图像预处理操作
ImageFolder假设所有的文件按文件夹保存好,每个文件夹下面存贮同一类别的图片,文件夹的名字为分类的名字
#https://blog.csdn.net/weixin_43135178/article/details/115133115
data_transforms = {
#训练集
'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
transforms.CenterCrop(224),#从中心开始裁剪
#P表示概率,有百分之50的概率反转,剩下不反转
transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
#将h,w,c[0.255]变为c,h,w[0.0,1.0]
transforms.ToTensor(),
#下面是归一化,前面是减均值,后面是比标准差
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
]),
#验证集不需要进行数据增强
'valid': transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
batch_size = 8
image_datasets = {
x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {
x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {
x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
#查看
image_datasets
dataloaders
dataset_sizes
2.2 读取标签对应的实际名字
with open('cat_to_name.json', 'r') as f:
cat_to_name = json.load(f)
#查看
cat_to_name
2.3 展示下数据
注意tensor的数据需要转换成numpy的格式,而且还需要还原回标准化的结果
def im_convert(tensor):
""" 展示数据"""
image = tensor.to("cpu").clone().detach()
image = image.numpy().squeeze()
#下面将图像还原回去,利用squeeze()函数将表示向量的数组转换为秩为1的数组,这样利用matplotlib库函数画图
#transpose是调换位置,之前是换成了(c,h,w),需要重新还为(h,w,c)
image = image.transpose(1,2,0)
image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
#clip的作用是小于0的都换成0,大于1的都变成1
image = image.clip(0, 1)
return image
fig=plt.figure