pytorch 自定义Dataset类

torch.utils.data.Dataset 是 PyTorch 数据处理模块中的一个核心类,用于表示一个数据集。通过继承和自定义 Dataset 类,用户可以轻松管理和加载各种类型的数据,如图像、文本、时间序列等。

1. Dataset 类的作用

Dataset 提供了一种标准接口,方便用户自定义数据加载逻辑,尤其是对于大型数据集。每个自定义的数据集类需要实现两个核心方法:

  • __len__():返回数据集中样本的数量。
  • __getitem__(index):根据给定的索引返回数据集中的一个样本(通常包括特征和标签)。

2. 自定义 Dataset

Dataset 是一个抽象类,因此你需要通过继承它来定义自己的数据集,并实现其中的 __len__ 和 __getitem__ 方法。以下是如何自定义一个简单的 Dataset 的示例。

示例代码
import torch
from torch.utils.data import Dataset

# 自定义数据集类,继承自 torch.utils.data.Dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # 初始化数据集,传入数据和标签
        self.data = data
        self.labels = labels

    def __len__(self):
        # 返回数据集中样本的数量
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引返回一个样本和其对应的标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 示例数据
data = torch.randn(100, 3)  # 100 个样本,每个样本有 3 个特征
labels = torch.randint(0, 2, (100,))  # 100 个样本的标签,二分类(0 或 1)

# 创建数据集实例
dataset = MyDataset(data, labels)

# 访问数据集中的第一个样本
sample, label = dataset[0]
print("Sample:", sample)
print("Label:", label)

解释:

  • __init__(self, data, labels):构造函数中,我们将数据和标签传入并保存为类的成员变量。
  • __len__(self):返回数据集的样本数量。
  • __getitem__(self, idx):根据索引 idx,返回数据和标签。

3. 与 DataLoader 配合使用

自定义的 Dataset 类通常与 DataLoader 配合使用。DataLoader 提供了批量数据加载、打乱顺序、并行加载等功能。

from torch.utils.data import DataLoader

# 使用 DataLoader 加载数据集
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# 迭代 DataLoader
for batch_data, batch_labels in dataloader:
    print(batch_data, batch_labels)

解释:

  • batch_size=4:每次加载 4 个样本。
  • shuffle=True:在每个 epoch 之前将数据打乱。

4. 常见的 Dataset 子类

PyTorch 提供了一些常用的 Dataset 子类,如:

  • torchvision.datasets:用于加载图像数据集(如 CIFAR、MNIST 等)。
  • torchtext.datasets:用于加载文本数据集(如 IMDB、WikiText 等)。
  • torch.utils.data.TensorDataset:将一对张量(如数据和标签)封装成一个数据集。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值