利用Pytorch搭建简单的图像分类模型(之一)---读取数据

基于Pytorch的几种读取图像数据的方式(分类模型)

一、读取Pytorch官方所包含的数据集(如Imagenet,CIFAR10,MNIST)

对于Pytorch自带的数据集,只需要调用torchvision.datasets.XXXX()即可,例如想要读取CIFAR10数据集:torchvision.datasets.CIFAR10()

'''导入读取图片数据所需要的工具包'''
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data

在这里先介绍一下torchvision.transforms模块,这个模块对于图像数据来说是非常重要的,它提供了常用的一些数据增强的方法(裁剪、翻转、旋转、尺寸变换等)。

  1. 裁剪(Crop)
    裁剪主要分为中心裁剪、随机裁剪等,我自己主要常用的有:中心裁剪transforms.CenterCrop(size),随机裁剪:transforms.RondomCrop(size),随机长宽比裁剪:transforms.RandomResizedCrop(size)

    代码作用
    transforms.CenterCrop(size)从图像中间裁剪出尺寸为size的图片
    transforms.RondomCrop(size)从图像中随机裁剪出尺寸为size的图片
    transforms.RandomResizedCrop(size)随机大小、长宽比裁剪出尺寸为size的图片

    其中size就是输入图像通过transforms变换后的尺寸(通常也就是模型输入图片的尺寸),当然还有其他的参数,但是我用的比较少一般都是默认,一般只需要修改尺寸就行,size如果是一个int值(size=n),则最后裁剪得到的是一个正方形图像(尺寸为n×n);size如果是一个数对(h,w),则最后裁剪得到的图像为一个矩形(尺寸为h×w)

  2. 翻转、旋转(Flip and Rotation)
    自己用的比较多的有:依概率p水平翻转:transforms.RandomHorizontalFlip(p) 依概率p垂直翻转:transforms.RandomVerticalFlip(p) 随机旋转:transforms.RandomRotation(degrees)

    代码作用
    transforms.RandomHorizontalFlip(p)按照概率p来对图像进行水平翻转
    transforms.RandomVerticalFlip(p)按照概率p来对图像进行垂直翻转
    transforms.RandomRotation(degrees)在给定角度范围内随机旋转图片

    其中p为旋转的概率,这个根据不同需求选择不同的数值;而对于随机旋转有点不一样,degrees为旋转的角度范围,是一个数对(-degrees, +degrees)。

  3. 图像变换(resize)
    自己在读取数据时常用的有:尺寸变换:transforms.Resize(size)、标准化:transforms.Normalize(mean, std)以及将图片数据转为tensor类型并归一化至[0, 1]:transforms.ToTensor()

    代码作用
    transforms.Resize(size)将图像尺寸变换成size
    transforms.Normalize(size)将输入图像按批次标准化
    transforms.ToTensor()将输入图像归一化成[0, 1]之间的tensor

    这里详细讲解一下transforms.Normalize(mean, std)的作用

    • 假设经过裁剪、旋转等操作过后的图像数据为x(x∈[0, 255]);
    • 经过transforms.ToTensor()后,x=(x/255)∈[0, 1];
    • 最后经过transforms.Normalize(mean, std)(假设mean=0.5,std=0.5),x=[(x-0.5)/0.5]∈[-1, 1] -> (x_mean=0,x_std=1)

    其中size为转换后预期的图像尺寸,mean代表图像的均值、std代表图像的方差。 这里要注意的是,transforms.ToTensor()是直接将图片数据的每个像素值除以255。

上述针对图像的各类操作函数没有固定的搭配,都是根据具体需要具体去选择,最后,transforms模块的整体代码为:

data_transform = transforms.Compose([transforms.CenterCrop(32),            # 对输入图像进行中心裁剪,裁剪后的图像尺寸为3*32*32(对于3通道RGB图像)
                                    transforms.RandomHorizontalFlip(0.5),  # 对输入图像按照0.5的概率进行水平翻转
                                    transforms.ToTensor(),                 # 将PIL Image或者ndarray类型的数据转换为tensor类型(非常关键,一定要加)
                                    transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]) # 对图像数据进行标准化
                 

通过torchvision.transforms编辑好所需要的图像变换规则后,就可以调用torchvision.datasetstorch.utils.data.DataLoader来读取与加载所需要的图片数据:

batchsize = 128
'''获取训练集与测试集数据'''
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=data_transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=data_transform, download=True)
'''加载训练集与测试集数据'''
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)
Files already downloaded and verified
Files already downloaded and verified

训练集和测试集数据准备完毕,下面可以开始进行模型的训练与测试了

  1. datasets.CIFAR10(root=‘./data’, train=True, transform=data_transform, download=True)
    • root=:指出数据存放的路径,可以是绝对路径也可以是相对路径。若路径已存在,则数据直接保存在相应路径下;若路径不存在,则先创建路径然后将数据保存至路径下;
    • train=:bool型变量,若为True则表示获取训练集数据;若为False则表示获取测试集数据;
    • transform=:调用数据增强的各种方法,这里调用的是上文中所编写的data_transform,可视情况而定;
    • download=:bool型变量,这里一般都设置为True,在第一次运行时会自动将所需数据集下载到指定路径下,后续运行不会重复下载(路径中已存在相应数据集)。
  2. datasets.CIFAR10(root=‘./data’, train=True, transform=data_transform, download=True)
    • train_dataset:指出要加载的数据集,在这里为train_dataset表示要加载训练数据集,为test_dataset则表示要加载测试数据集;
    • batch_size=:设定一批次的图片数量,这个超参没有固定值,根据数据集、模型大小以及硬件配置自行设置;
    • shuffle=:bool型变量,若为True,则将打乱加载数据的顺序,若为False则不会打乱顺序;
    • num_workers=:可以理解为加载数据的通道数量,一般来说num_workers越多加载数据速度越快。

若要可视化所读取和加载的数据,可以调用以下方法(这里只做简单的介绍,通常可视化是用来简单检测读取加载数据是否正确,训练时一般会去掉):

'''可视化数据集所包含的类别'''
classes = train_dataset.classes
print(classes)
print('-----------')

'''可视化数据集的数量与各个图片的尺寸'''
print(train_dataset.data.shape) # (50000, 32, 32, 3),50000表示数据集图片的数量;按顺序:32表示h,32表示w,3表示c。
print('-----------')

'''可视化加载后一批次的图片数量以及各图片尺寸'''
for data, label in train_loader:
    print(data.shape)   # 一批次的图片数量为128(batchsize),图片的尺寸为(3, 32, 32),。
    print(label.shape)  # 每个图片对应的标签
    break

'''可视化一些样本图片'''
import matplotlib.pyplot as plt # 一个画图的包
plt.imshow(train_dataset.data[1])
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
-----------
(50000, 32, 32, 3)
-----------
torch.Size([128, 3, 32, 32])
torch.Size([128])

<matplotlib.image.AxesImage at 0x7f575402f910>

在这里插入图片描述

二、加载文件夹类型的图片数据集

针对这种方式,用的相对就多了,其对应的图片数据储存方式应该为:
根路径/数据集文件夹(weather_5)/train(test)/类别文件夹(label)/图片

这里以自己做的一个天气分类模型所用到的数据集为例:
在这里插入图片描述
在这里插入图片描述

保存成这样的格式之后,就可以直接利用pytorch定义好的派生类ImageFolder来读取了。ImageFolder其实就是Dataset的派生类,专门被定义来读取特定格式的图片的,它也是torchvision库帮中为了我们方便读取文件夹类型的图片数据而创建的。

'''导入读取图片数据所需要的工具包'''
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data

batchsize = 2

'''同样先定义transform'''
data_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(0.5),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])

'''获取训练集与测试集数据'''               
train_dataset = datasets.ImageFolder(root='./data/weather_5/train/', transform=data_transform)
test_dataset = datasets.ImageFolder(root='./data/weather_5/test/', transform=data_transform)

'''加载训练集与测试集数据'''
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)

这里只对datasets.ImageFolder做一个简单的介绍。
通过train_dataset = datasets.ImageFolder后返回的train_dataset包含以下三种属性:

  • train_dataset.class:用一个list保存数据集中类别名称
  • train_dataset.class_to_idx:类别对应的数字索引
  • train_dataset.imgs:保存(img_path, class)tuple的 list

重点关注下面两块小代码,可以更清晰的了解datasetsDataLoader的差别

'''可视化读取到的数据集第一张图片的尺寸与其标签'''
for data, label in train_dataset:
    print(data.shape)   # 第一张图片的为(3, 224, 224)
    print(label)        # 第一张图片的标签为0
    break
print('-----------')

'''可视化加载后一批次的图片数量以及各图片尺寸'''
for data, label in train_loader:
    print(data.shape)   # 一批次的图片数量为2(batchsize),图片的尺寸为(3, 224, 224)
    print(label.shape)  # 每个图片对应的标签
    print(label[0])     # 可以将torch.utils.data.DataLoader()中的shuffle变量修改为False,比较输出有什么不同。
    break

print(train_dataset.class_to_idx)
torch.Size([3, 224, 224])
0
-----------
torch.Size([2, 3, 224, 224])
torch.Size([2])
tensor(4)
{'cloudy': 0, 'haze': 1, 'rainy': 2, 'snow': 3, 'sunny': 4}

三、根据.CSV文件加载图像数据

这种加载图像数据的方式应该是现在用的最多的一种方式,这种方式Pytorch就没有现成的方法让我们直接去加载数据了,但是我们可以基于Pytorch定义我们自己的Dataset类。
这里以另一个天气分类模型所用到的数据集为例:
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

'''导入读取图片数据所需要的工具包'''
import torch.utils.data
import torchvision.transforms as transforms
from torch.utils.data import Dataset

import os
import pandas as pd
from PIL import Image

'''定义transform'''
data_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(0.5),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

'''定义自己的Dataset类'''
class MyDataset(Dataset):

    def __init__(self, csv_file, root_dir, transform=None):
        """
            csv_file: 标签文件的路径.
            root_dir: 所有图片的路径.
            transform: 一系列transform操作
        """
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        # print(self.data_frame.info) # 可以查看通过pandas读取后的data_frame具体包含些什么

    def __getitem__(self, idx):
        '''获取图片'''
        img_path = os.path.join(self.root_dir, 
                                self.data_frame.iloc[idx, 0]) #获取图片所在路径          
        img = Image.open(img_path).convert('RGB')   # 防止有些图片不是RGB格式
        
        '''获取标签'''
        label_number = self.data_frame.iloc[idx, 1] # 获取图片的类别标签
        
        '''判断是否要进行图像变换'''
        if self.transform:
            img = self.transform(img)

        return img, label_number # 返回图片和标签

    def __len__(self):
        return len(self.data_frame) # 返回数据集长度用来建立索引idx

'''调用自己的Dataset类来读取数据'''
train_dataset = MyDataset(csv_file='./data/weather/Train_label.csv',
                          root_dir='./data/weather/Train',
                          transform=data_transform)
'''将读取好的数据进行加载'''                          
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=2, num_workers=8) # 加载数据集

'''可视化加载后一批次的图片数量以及各图片尺寸'''
for data, label in train_iter: # 迭代batchsize中的数据
    print(data.shape) # torch.Size([128, 3, 224, 224])
    print(label.shape) # torch.Size([128])
    break

torch.Size([2, 3, 224, 224])
torch.Size([2])

总结一下,在定义自己的Dataset类的时候注意三点就可以:

  • def __init__(self):声明并初始化需要用到的变量,比如.csv路径、图片数据存放的文件夹路径等
  • def __getitem__(self, idx):根据路径逐一读取样本的图像和标签,并返回一个元组
  • def __len__(self):获取数据集中样本的数量

这里简单介绍一下上一块代码用到的几个工具包:import osimport pandas as pdfrom PIL import Image:

  • import os:在python下写程序,需要对文件以及文件夹或者其他的进行一系列的操作
  • import pandas as pd:pandas这个库是用来读取数据的,比如.csv文件以及Excel文件,上一块代码用到的pd.read_csv(csv_file)就是读取.csv文件的指令
  • from PIL import Image:Python图像库PIL(Python Image Library)是Python的第三方图像处理库,可以做很多和图像处理相关的事情:图像归档、可视化、处理等,如上一块代码中用到的img = Image.open(img_path).convert('RGB')来读取图片数据,并将其转化成RGB模式

在这里详细介绍一下有关os模块常用的指令,对于其它两个模块由于涉及的内容太多,可以去官网查看具体用法。

  • import os:在python下写程序,需要对文件以及文件夹或者其他的进行一系列的操作,就需要引入os模块,常用的指令有:

    指令作用
    os.path.join(path_1, path_2)将路径path_1和path_2拼接起来形成新路径
    os.path.split(path)将path分割成目录和文件名并以元组方式返回
    os.splitext(path)分离扩展名然后按照元组返回
    os.path.exists(path)如果path是一个存在的路径,返回True,否则返回 False

四、总结

总的来说,对于计算机视觉的数据读取和加载,就是利用好torchvision中的transforms和datasets以及torch.utils.data中的Dataloader类,当然还有很多种读取图像、视频数据的方法。

  1. 上来先查看数据集,不管是文件夹也好还是.CSV文件也好,先弄清楚数据集的结构,这样才能知道用何种方式去读取、加载数据;
  2. 了解数据集结构以后,根据自身任务调用torchvision.transforms去设计自己的transform模块;
  3. 根据不同的数据集结构,采用不同的Dataset方式去读取数据集。对于Dataset只需要把握住一点:通过路径去一张一张的获取图像和标签,最后返回值是一个包含多个元组[tupel(img, label)]的一个list;
  4. 最后通过Dataloader对Dataset进行批量加载。

完整项目在我上传的资源里面,需要的可以自取

  • 3
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值