使用DataLoader的小例子,这里CustomDataset类的__getitem__方法需要返回tensor。
加载到DataLoader中之后,DataLoader会通过类似字典的方式读取CustomDataset中的数据,达到批量处理的效果。
import torch
from torch.utils.data import Dataset, DataLoader
# 定义数据集类
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
return torch.tensor(sample) # 假设样本是一个列表,将其转换为张量返回
# 创建数据集实例
data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
dataset = CustomDataset(data)
# 创建数据加载器实例
batch_size = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 遍历数据加载器,批量读取数据
for batch in dataloader:
print(batch)
"""
tensor([[7, 8, 9],
[1, 2, 3]])
tensor([[4, 5, 6]])
"""