图像识别实战(一)----数据集的预处理

图像识别实战(一)----数据集的预处理

1.模块的导入

import os
import matplotlib.pyplot as plt

import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2.数据集的读取

data_dir = './flower_data'
train_dir = data_dir+ '/train'
valid_dir = data_dir+ '/valid'

3.数据集的预处理

data_transforms = {
    'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45度到45度之间
                                transforms.CenterCrop(224),#从中心开始裁剪
                                transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率
                                transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
                                transforms.ColorJitter(brightness=0.2, contrast=0.1,saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
                                transforms.RandomGrayscale(p=0.025),#概率转化为灰度值,3通道就是R=G=B
                                transforms.ToTensor(),#转化为Tensor格式,在预处理结束后必须添加
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#均值,标准差,经过这样处理后的数据符合标准正态分布,即均值为0,标准差为1。使模型更容易收敛。
    'valid': transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])]), 

transforms.Compose()这个类的主要功能是串联图片的变换操作,类似于一个列表。

4.数据集的组织与加载

batch_size = 8
image_datasets = {x:datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x]) for x in ['train','valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}#就是用来包装所使用的数据,每次抛出一批数据
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
dataset=torchvision.datasets.ImageFolder(
                       root, #图片存储的根目录
                       transform=None, #图片的预处理操作
                       target_transform=None, #对图片类别做预处理操作
                       loader=<function default_loader>, #数据集加载方式
                       is_valid_file=None)#获取图像文件的路径并检查该文件是否为有效文件
#print(dataset.classes)  #根据分的文件夹的名字来确定的类别
#print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
#print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别

我们打印出 image_datasets

{'train': Dataset ImageFolder
     Number of datapoints: 3614
     Root location: F:/flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=(-45, 45), resample=False, expand=False)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 56
     Root location: F:/flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=PIL.Image.BILINEAR)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}

我们打印出 dataset_sizes 帮助理解{}中的逻辑

{'train': 3614, 'valid': 56}

5.数据集图像展示

def im_convert(tensor):
    """展示数据"""
    image = tensor.to('cpu').clone().detach()#将Tensor数据从GPU放到CPU,复制和这个Tensor并且去掉梯度
    image = image.numpy().squeeze()#祛除数组中为1 的维度
    image = image.transpose(1,2,0)#Pytorch中为[Channels, H, W],而plt.imshow()中则是[H, W, Channels],所以交换一下通道
    image = image*np.array((0.229, 0.224, 0.225))+np.array((0.485, 0.456, 0.406))# 反转一下transforms.Normalize()的过程
    image = image.clip(0, 1)#归一化
    return image
fig = plt.figure(figsize=(20, 12))#设置图像尺寸
columns = 4
rows = 2
#我们设置的一个batchsize=8,所以dataloaders里只有8张图片,最多显示8张图片
dataiter = iter(dataloaders['valid'])#iter()迭代器
inputs, classes = dataiter.next()
for idx in range (columns*rows):
    ax = fig.add_subplot(rows,columns, idx+1,xticks=[], yticks=[])#图像区域划分row行,colums列,第idx+1个
    ax.set_title(class_names[classes[idx].item()])
    plt.imshow(im_convert(inputs[idx]))

plt.show()   

在这里插入图片描述

  • 3
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

NAND_LU

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值