Pytorch自定义Dataset,及__getitem__()访问

在继承 torch.utils.data.Dataset 并重写其中的 __getitem__() 方法后,调用该方法的方式与调用任何类的方法类似。以下是一个完整的示例,展示如何实现和调用重写的 __getitem__() 方法。

实现自定义 Dataset 类

首先,实现继承自 torch.utils.data.Dataset 的自定义类,并重写 __getitem__() 方法。

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        """
        初始化自定义数据集。

        Args:
            data (list): 数据列表。
        """
        self.data = data

    def __len__(self):
        """
        返回数据集的大小。

        Returns:
            int: 数据集的大小。
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        根据索引获取数据项。

        Args:
            idx (int): 数据项的索引。

        Returns:
            data: 索引对应的数据项。
        """
        return self.data[idx]

实例化数据集并调用 __getitem__() 方法

可以通过实例化自定义数据集类,然后直接调用 __getitem__() 方法来获取数据项。注意,通常不直接调用 __getitem__() 方法,而是通过 dataset[idx] 的方式调用。

# 创建示例数据
data = [1, 2, 3, 4, 5]

# 实例化自定义数据集
dataset = MyDataset(data)

# 直接调用 __getitem__() 方法
item = dataset.__getitem__(0)
print(f"直接调用 __getitem__(): {item}")

# 通过索引访问数据项(推荐方式)
item = dataset[0]
print(f"通过索引访问数据项: {item}")

# 使用 DataLoader
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 遍历 DataLoader
for batch in dataloader:
    print(f"批次数据: {batch}")

输出结果

直接调用 __getitem__(): 1
通过索引访问数据项: 1
批次数据: tensor([2, 3])
批次数据: tensor([4, 5])
批次数据: tensor([1])

总结

  1. 实现自定义数据集类,并重写 __getitem__() 方法。
  2. 通过实例化自定义数据集类并使用索引访问数据项。
  3. 通常使用 dataset[idx] 语法来调用 __getitem__() 方法,而不是直接调用。
  4. 可以使用 DataLoader 来处理批量数据。

这样可以方便地管理和处理数据集,同时利用 PyTorch 提供的其他功能,如批量加载和数据打乱等。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值