DataLoader批量读取数据

使用DataLoader的小例子,这里CustomDataset类的__getitem__方法需要返回tensor。
加载到DataLoader中之后,DataLoader会通过类似字典的方式读取CustomDataset中的数据,达到批量处理的效果。

import torch
from torch.utils.data import Dataset, DataLoader

# 定义数据集类
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return torch.tensor(sample)  # 假设样本是一个列表,将其转换为张量返回


# 创建数据集实例
data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
dataset = CustomDataset(data)

# 创建数据加载器实例
batch_size = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 遍历数据加载器,批量读取数据
for batch in dataloader:
    print(batch)
"""
tensor([[7, 8, 9],
        [1, 2, 3]])
tensor([[4, 5, 6]])
"""
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值