Pytorch学习日记2:Datasets&DataLoaders

本文介绍了如何在PyTorch中使用torchvision加载Fashion-MNIST数据集,包括数据集的使用、自定义数据集的创建以及DataLoader的运用。教程详细展示了如何从torchvision获取数据并进行预处理,以及如何创建和使用自定义数据集进行训练数据的加载。
摘要由CSDN通过智能技术生成

Pytorch官网教程:https://pytorch.org/tutorials/  

主要内容:

1.从torchvision中加载数据集

2.自定义的数据集

3.DataLoader的使用

1.从torchvison中加载数据集

首先介绍一下torchvision,专门用来处理图像,通常应用于计算机视觉领域。常用的三个包:models(提供训练好的网络模型)、datasets(提供常用的图片数据集以及加载数据集的常用方法)、transforms(提供常见的图像转换操作)

更详细可见官网或PyTorch:Torchvision的简单介绍与使用

下面举例从torchvision中加载Fashion-MNIST数据集,难度不大,看注释即可。

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor   # 图像变换包
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",       # root定义的参数为训练集或测试集存放的根目录
    train=True,        # True为训练集,False为测试集
    download=True,     # True为从互联网上下载数据如果root中没有
    transform=ToTensor()  # 使用torchvison中定义的类对输入图片进行变换,输入的图片是PIL image类型
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Iterating and Visualizing the Dataset
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1): # 1~9
    sample_idx = torch.randint(len(training_data), size=(1,)).item() # 随机生成样本索引值,item()把将张量转化为标量
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    # 原始图片大小 [1,28,28]
    plt.imshow(img.squeeze(), cmap="gray")  # img.squeeze()去掉为1的维度,显示灰度图像
plt.show()

2.自定义的数据集

torch.utils.data.Dataset方法可以通过创建Dataset类功能的子类来创建自定义数据集。

但必须覆盖__len__和__getitem__。

# Creating a Custom Dataset for your files
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):  # 继承torch.utils.data.Dataset
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file) # 标签
        self.img_dir = img_dir  # 图片
        self.transform = transform  
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])  # 衔接图像路径
        image = read_image(img_path)  #获取图片
        label = self.img_labels.iloc[idx, 1]  # 对应标签
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

自定义的数据集CustomImageDataset继承自torch.utils.data.Dataset。

__init__方法:主要工作是初始化类的一些函数和参数,以便在__getitem__中使用。

                        制作__getitem__函数所要用到的图片和对应标签的list。

__len__方法:返回数据集大小

__getitem__方法:读数据、预处理数据(例如torchvision.transform)、返回图像和标签

如果没有理解,有完整实例:PyTorch基础-自定义数据集和数据加载器(2)

3.DataLoader的使用

使用DataLoader准备训练的数据

# Preparing your data for training with DataLoaders
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
# batch_size为64,一次处理一批64张图像以及对应的标签数据  # shuffle 数据洗牌

使用DataLoader进行迭代

# Iterate through the DataLoader
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))  # iter()返回一个代表数据流的对象  next()获取数据流的下一个元素
print(f"Feature batch shape: {train_features.size()}")  # [64,1,28,28]
print(f"Labels batch shape: {train_labels.size()}")     #[64]
img = train_features[0].squeeze()  # 删除第一行数据中的一维维度
label = train_labels[0]            # 获取对应标签
plt.imshow(img, cmap="gray")       # 显示灰度图像
plt.show()
print(f"Label: {label}")

张量中包含图像数据的每个维度的大小由以下每个值定义:(批量大小,颜色通道数,图像高度,图像宽度)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

南风知我意95

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值