在继承 torch.utils.data.Dataset
并重写其中的 __getitem__()
方法后,调用该方法的方式与调用任何类的方法类似。以下是一个完整的示例,展示如何实现和调用重写的 __getitem__()
方法。
实现自定义 Dataset 类
首先,实现继承自 torch.utils.data.Dataset
的自定义类,并重写 __getitem__()
方法。
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data):
"""
初始化自定义数据集。
Args:
data (list): 数据列表。
"""
self.data = data
def __len__(self):
"""
返回数据集的大小。
Returns:
int: 数据集的大小。
"""
return len(self.data)
def __getitem__(self, idx):
"""
根据索引获取数据项。
Args:
idx (int): 数据项的索引。
Returns:
data: 索引对应的数据项。
"""
return self.data[idx]
实例化数据集并调用 __getitem__()
方法
可以通过实例化自定义数据集类,然后直接调用 __getitem__()
方法来获取数据项。注意,通常不直接调用 __getitem__()
方法,而是通过 dataset[idx]
的方式调用。
# 创建示例数据
data = [1, 2, 3, 4, 5]
# 实例化自定义数据集
dataset = MyDataset(data)
# 直接调用 __getitem__() 方法
item = dataset.__getitem__(0)
print(f"直接调用 __getitem__(): {item}")
# 通过索引访问数据项(推荐方式)
item = dataset[0]
print(f"通过索引访问数据项: {item}")
# 使用 DataLoader
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 遍历 DataLoader
for batch in dataloader:
print(f"批次数据: {batch}")
输出结果
直接调用 __getitem__(): 1
通过索引访问数据项: 1
批次数据: tensor([2, 3])
批次数据: tensor([4, 5])
批次数据: tensor([1])
总结
- 实现自定义数据集类,并重写
__getitem__()
方法。 - 通过实例化自定义数据集类并使用索引访问数据项。
- 通常使用
dataset[idx]
语法来调用__getitem__()
方法,而不是直接调用。 - 可以使用
DataLoader
来处理批量数据。
这样可以方便地管理和处理数据集,同时利用 PyTorch 提供的其他功能,如批量加载和数据打乱等。