pytorch数据集加载之DataSet和DataLoader

在学习深度学习初期,被各种算法、各种名词吸引,学习了一些相关知识、框架和方法,也跑过Mnist分类以及iris回归。但是当真正拿到一个陌生的数据集时,需要重头开始搭建一个完整的模型时,常常会感觉到无从下手,之前跑的模型都是用人家整理好的数据,一行代码就能把数据加载进行直接使用,自己完全不关心数据的加载、处理过程,只关心模型能不能训练,训练结果怎么样。这篇文章简单记录一下pytorch中自定义数据集的使用方法。

在pytorch中涉及到数据集加载的模块有两个,一个是DataSet,另一个是DataLoader。pytorch中数据加载的核心是torch.utils.data.DataLoader类,支持映射类型(map-style)和迭代类型(iterable-style)的数据集。

DataSet的描述如下:

根据描述可知,Dataset是一个抽象类,子类需要实现其中的__getitem__()方法用于获取数据集中的元素,实现__len__()方法用于获取数据集的大小。

DataLoader的构造函数定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

其中,dataset是一个加载数据集的对象,batch_size是批量大小,num_workers表示用几个子线程来并行加载数据。


下面讲一下pytorch数据加载支持的两种类型:

1、映射类型map-style

根据官网的描述,map-style类型的数据集可以简单理解为键值对类型的数据集,键可以是字典的key,也可以是数组的index,总之可以通过类似于dataset[idx]这种方式进行访问。要加载map-style类型的数据,需要重写Dataset的__getitem__()和__len__()方法。

2、迭代类型iterable-style

加载iterable-style类型的数据,需要实现IterableDataset类的__iter__()方法,适用于加载流式等不便于进行shuffle的数据。


下面通过两种典型场景讲一下使用pytorch加载map-style类型的数据集方法。

场景一:加载pandas.DataFrame或numpy数组,此种场景常出现在处理回归问题时加载原始数据集。

import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader

# 加载pandas.DataFrame,需要使用values将DataFrame先转换为numpy数组
#  构造numpy数组
data_X = np.random.randn(100, 5)
data_y = 3 * data_X + 5

# 自定义Dataset的子类
class MyDataset(Dataset):
    # 构造器初始化方法
    def __init__(self, data_X, data_y):
        self.data_X = data_X
        self.data_y = data_y
    
    # 重写getitem方法用于通过idx获取数据内容
    def __getitem__(self, idx):
        return self.data_X[idx], self.data_y[idx]
    
    # 重写len方法获取数据集大小
    def __len__(self):
        return self.data_X.shape[0]

# 构造Dataset对象
dataset = MyDataset(data_X, data_y)
# 构造DataLoader对象
dataloader = DataLoader(dataset, batch_size=16, num_workers=0,shuffle=True, drop_last=False)

for batch_X, batch_y in dataloader:
    print(batch_X.shape, batch_y.shape)

输出结果如下:

torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([16, 5]) torch.Size([16, 5])
torch.Size([4, 5]) torch.Size([4, 5])

场景二:加载磁盘上的图像数据,此场景常出现在处理图像分类问题时加载图像数据集。

图像数据的存放格式如下,猫的图片存放在cat目录下,狗的图片存放在dog目录下:

from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision
import random
from matplotlib import pyplot as plt

class MyDataset(Dataset):
        # 构造器初始化方法
    def __init__(self, filenames, labels, transforms=None):
        self.filenames = filenames
        self.labels = labels
        self.transforms = transforms
    
    # 重写getitem方法用于通过idx获取数据内容
    def __getitem__(self, idx):
        # 使用Pillow Image读取图片文件
        image = Image.open(self.filenames[idx]).convert("RGB")
        # 对图像数据进行转换
        if self.transforms is not None:
            image = self.transforms(image)
        return image, self.labels[idx]
    
    # 重写len方法获取数据集大小
    def __len__(self):
        return len(self.filenames)
    
def show_image(images, labels, classes):
    fig, axes = plt.subplots(1, 4, figsize=(15, 8))
    for index, image in enumerate(images):
        # pytorch中Tensor的shape是[C, H, W],使用matplotlib显示时,需要转换shape到[H, W, C]
        image = image.numpy().transpose(1, 2, 0)
        label = labels[index]
        axes[index].set_title(classes[label])
        axes[index].imshow(image)
    

# 定义图像预处理转换方法
transforms = torchvision.transforms.Compose(
    [
        # torchvision.transforms处理的目标是Image对象
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.RandomGrayscale(p=0.3),
        # 将Image对象转换为Tensor张量
        torchvision.transforms.ToTensor()
    ]
)

image_dataset = torchvision.datasets.ImageFolder("../data/cat_and_dog")
# image_dataset.samples 中存放的是图像数据的文件路径和类别索引编号(从0开始编号)
random.shuffle(image_dataset.samples)
# image_dataset.classes 列表中存放的类别顺序与image_dataset.samples中存放的类别索引编号相对应
classes = image_dataset.classes
# print(image_dataset.samples[:5])

# 用于存放图像路径列表
filenames = []
# 用于存放图像对应的类别
labels = []
for image_path, label in image_dataset.samples:
#     print(image_path, label)
    filenames.append(image_path)
    labels.append(label)

dataset = MyDataset(filenames, labels, transforms)
dataloader = DataLoader(dataset,batch_size=4, num_workers=0, shuffle=True, drop_last=False)

for images, labels in dataloader:
    print(images.shape, labels)
    # 显示读取到的图像数据,并验证类别信息是否真确
    show_image(images, labels, classes)

输出结果如下:

torch.Size([4, 3, 224, 224]) tensor([0, 1, 1, 1])
torch.Size([4, 3, 224, 224]) tensor([1, 1, 0, 0])
torch.Size([4, 3, 224, 224]) tensor([0, 1, 1, 0])
torch.Size([4, 3, 224, 224]) tensor([0, 0, 1, 0])
torch.Size([4, 3, 224, 224]) tensor([0, 1, 1, 0])

 

 

  • 6
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PyTorch中,数据读取是构建深度学习模型的重要一环。为了高效处理大规模数据集,PyTorch提供了三个主要的工具:DatasetDataLoader和TensorDatasetDataset是一个抽象类,用于自定义数据集。我们可以继承Dataset类,并重写其中的__len__和__getitem__方法来实现自己的数据逻辑。__len__方法返回数据集的大小,而__getitem__方法根据给定的索引返回样本和对应的标签。通过自定义Dataset类,我们可以灵活地处理各种类型的数据集。 DataLoader数据器,用于对数据集进行批量。它接收一个Dataset对象作为输入,并可以定义一些参数例如批量大小、是否乱序等。DataLoader能够自动将数据集划分为小批次,将数据转换为Tensor形式,然后通过迭代器的方式供模型训练使用。DataLoader数据准备和模型训练的过程中起到了桥梁作用。 TensorDataset是一个继承自Dataset的类,在构造时将输入数据和目标数据封装成Tensor。通过TensorDataset,我们可以方便地处理Tensor格式的数据集。TensorDataset可以将多个Tensor按行对齐,即将第i个样本从各个Tensor中取出,构成一个新的Tensor作为数据集的一部分。这对于处理多输入或者多标签的情况非常有用。 总结来说,Dataset提供了自定义数据集的接口,DataLoader提供了批量数据集的能力,而TensorDataset则使得我们可以方便地处理Tensor格式的数据集。这三个工具的配合使用可以使得数据处理变得更方便和高效。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值