图像识别实战常用模块解读(练习,小白记录)

  1. 基于经典网络架构训练图像分类模型¶

数据预处理部分:
数据增强:torchvision中transforms模块自带功能,比较实用
数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可
DataLoader模块直接读取batch数据

  1. 网络模块设置:

加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习
需要注意的是别人训练好的任务跟咱们的可不是完全一样,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务
训练时可以全部重头训练,也可以只训练最后咱们任务的层,因为前几层都是做特征提取的,本质任务目标是一致的

  1. 网络模型保存与测试

模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存
读取模型进行实际测试

1 模块导入

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
#pip install torchvision
from torchvision import transforms, models, datasets
#https://pytorch.org/docs/stable/torchvision/index.html
#imageio:一个简单的接口来读取和写入各种图像数据
#sys:该模块提供对解释器使用或维护的一些变量的访问,以及与解释器强烈交互的函数
#json:使用 json 模块来对 JSON 数据进行编解码
#PIL是Python平台事实上的图像处理标准库,支持多种格式,并提供强大的图形与图像处理功能。
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2 数据读取与预处理操作

#路径设置
data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

2.1制作好数据源:

data_transforms中指定了所有图像预处理操作
ImageFolder假设所有的文件按文件夹保存好,每个文件夹下面存贮同一类别的图片,文件夹的名字为分类的名字

#https://blog.csdn.net/weixin_43135178/article/details/115133115
data_transforms = {
   
    #训练集
    'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(224),#从中心开始裁剪
                                 #P表示概率,有百分之50的概率反转,剩下不反转
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        #将h,w,c[0.255]变为c,h,w[0.0,1.0]
        transforms.ToTensor(),
                    #下面是归一化,前面是减均值,后面是比标准差
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    ]),
    #验证集不需要进行数据增强
    'valid': transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
batch_size = 8
image_datasets = {
   x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {
   x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {
   x: len(image_datasets[x]) for x in ['train', 'valid']} 
class_names = image_datasets['train'].classes
#查看
image_datasets

在这里插入图片描述

dataloaders

在这里插入图片描述

dataset_sizes

在这里插入图片描述

2.2 读取标签对应的实际名字

with open('cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
#查看
cat_to_name

在这里插入图片描述

2.3 展示下数据

注意tensor的数据需要转换成numpy的格式,而且还需要还原回标准化的结果

def im_convert(tensor):
    """ 展示数据"""
    
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    #下面将图像还原回去,利用squeeze()函数将表示向量的数组转换为秩为1的数组,这样利用matplotlib库函数画图
    #transpose是调换位置,之前是换成了(c,h,w),需要重新还为(h,w,c)
    image = image.transpose(1,2,0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    #clip的作用是小于0的都换成0,大于1的都变成1 
    image = image.clip(0, 1)

    return image
fig=plt.figure
  • 6
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值