关于pytorch读取数据集的一些知识点

本文介绍了如何利用torchvision加载常见的数据集如MNIST、CIFAR10等,以及如何通过ImageFolder处理特定结构的数据。还展示了如何自定义数据集类MyDataset,以适应普通数据集的加载,并进行了数据增强操作。torchvision提供了方便的数据预处理转换,如ToTensor和Normalize,简化了深度学习项目的准备工作。
摘要由CSDN通过智能技术生成

torchvision包提供了一些常用的数据集和转换函数,使用torchvision甚至不需要自己写处理函数。

一、对于torchvision提供的数据集

对于这一类数据集,PyTorch已经帮我们做好了所有的事情,连数据源都不需要自己下载。
Imagenet,CIFAR10,MNIST等等PyTorch都提供了数据加载的功能,所以可以先看看你要用的数据集是不是这种情况。

比如,加载MNIST数据集:

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    ),
    batch_size=TRAIN_BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data', train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    ),
    batch_size=TEST_BATCH_SIZE, shuffle=False)

二、对于特定结构的数据集

通过torchvision中的通用数据集ImageFolder来完成加载,它假设数据结构为如下:

root/airport/airplane(1).jpg
root/airport/airplane(2).jpg
root/airport/airplane(3).jpg
.
.
.
root/beach/beach(1).jpg
root/beach/beach(2).jpg
root/beach/beach(3)_.jpg

同样

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 数据增强
train_transforms = transforms.Compose(
        [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),  # 随机裁剪到256*256
        transforms.RandomRotation(degrees=45),  # 随机旋转
        transforms.RandomHorizontalFlip(),      # 随机水平翻转
        transforms.CenterCrop(size=224),        # 中心裁剪到224*224
        transforms.ToTensor(),                  # 转化成张量
        transforms.Normalize([0.485, 0.456, 0.406],  # 归一化
                             [0.229, 0.224, 0.225])
])

test_valid_transforms = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406],
                              [0.229, 0.224, 0.225])
])

# 利用Dataloader加载数据
train_directory = config.TRAIN_DATASET_DIR
valid_directory = config.VALID_DATASET_DIR

batch_size = config.BATCH_SIZE
num_classes = config.NUM_CLASSES

train_datasets = datasets.ImageFolder(train_directory, transform=train_transforms)
train_data_size = len(train_datasets)
train_data = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True)  # shuffle将序列的所有元素随机排序

三、对于普通数据集

定义数据集的类MyDataset,这个类要继承Dataset这个抽象类,然后重写下面的函数:

①__len__: 使得len(dataset)返回数据集的大小;

②__getitem__:使得支持dataset[i]能够返回第i个数据样本的下标操作

通常情况还包括初始函数__init__.

# 读取图片,主要是通过Dataset类
# 通过继承torch.utils.data.Dataset的这个抽象类,可以定义我们需要的数据类
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, img_paths, labels, transform=None):
        self.img_paths = img_paths
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):  # 实现索引数据集中的某一个数据
        img_path, label = self.img_paths[index], self.labels[index]

        img = cv2.imread(img_path)                  # 接口读取图片,读进来是BGR格式数据
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # 色彩空间转化函数cv2.cvtColor()进行色彩空间的转换,将BGR格式转换成RGB格式
        img = Image.fromarray(img)                  # numpy中的数组array转换成PIL中的image

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):               # 返回数据集的大小
        return len(self.img_paths)

train_set = MyDataset(
	train_img_paths,
	train_labels,
	transform=train_transform)

train_loader = torch.utils.data.DataLoader(
	train_set,
	batch_size=args.batch_size,
	shuffle=True,
	num_workers=4,
	sampler=None)

参考:链接: https://www.jianshu.com/p/6e22d21c84be.

待更新…

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一个吃吃喝喝不务正业GirL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值