pytorch教程4-----自定义数据集

文章介绍了如何在PyTorch中创建自定义图像数据集类,该类需实现__init__,__len__,和__getitem__方法。然后利用DataLoader进行数据预处理,如批量加载和随机洗牌,以便于模型训练。最后展示了如何遍历DataLoader获取训练样本。
摘要由CSDN通过智能技术生成

step1:自定义数据集类必须实现三个函数:__init__、__len__和__getitem__。

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
    """
    __init__函数在实例化数据集对象时运行一次。我们初始化 包含图像、注释文件和两个转换的目录
    """

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
    """
    __getitem__函数加载并从给定索引处的数据集返回样本。 根据索引,它识别图像在磁盘上的位置,将 
    其转换为张量,使用 ,检索 中 CSV 数据的相应标签,调用其上的转换函数(如果适用),并返回 元 
    组中的张量图像和相应的标签。
    """

step2:准备数据以使用数据加载程序进行训练。(检索数据集的特征并一次标记一个样本。在训练模型时,我们通常希望 在“小批量”中传递样本,在每个时期重新洗牌数据以减少模型过度拟合,并使用 Python 的 加快数据检索速度。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
#training_data为step1创建的
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

step3:遍历 DataLoader。(我们已将该数据集加载到DataLoader中,并可以根据需要循环访问该数据集。 下面的每次迭代都会返回一批 图像and标签(分别包含特征和标注)。 因为我们指定了shuffle=True(随机) ,在我们遍历所有批次后,数据将被洗牌。DataLoadertrain_featurestrain_labelsbatch_size=64shuffle=True

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

结果:

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值