DataLoader会将Dataset的数据转换为Tensor类型

工作原理

  1. Dataset: Dataset 是自己定义的数据集类,它需要实现 __len____getitem__ 方法。在 __getitem__ 方法中,你通常返回数据项(例如 numpy 数组、列表、字典等)。

  2. DataLoader: DataLoaderDataset 中取出数据,并自动处理批处理(batching)、打乱(shuffling)、多进程加载(multi-process loading)等操作。

  3. Collate 函数: DataLoader 使用 collate 函数来将单个数据项合并成一个批次。默认的 collate 函数会将 numpy 数组、Python 列表和其他支持的类型自动转换为 PyTorch 张量。

示例

以下是一个完整的示例,展示了如何使用 DataLoaderDataset 中加载数据并自动转换为张量。

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据
        self.data = np.random.randn(100, 3)  # 100个样本,每个样本3个特征
        self.labels = np.random.randint(0, 2, size=(100,))  # 100个样本,每个样本一个标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 返回数据项(数据和标签)
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集实例
dataset = MyDataset()

# 创建DataLoader实例
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# 遍历DataLoader
for batch_data, batch_labels in dataloader:
    print(f"数据类型: {type(batch_data)}, 形状: {batch_data.shape}")
    print(f"标签类型: {type(batch_labels)}, 形状: {batch_labels.shape}")
    break  # 只展示一个批次的数据

输出

数据类型: <class 'torch.Tensor'>, 形状: torch.Size([10, 3])
标签类型: <class 'torch.Tensor'>, 形状: torch.Size([10])

详细说明

  1. 自定义数据集: MyDataset 继承自 torch.utils.data.Dataset,实现了 __len____getitem__ 方法。在 __getitem__ 方法中,我们返回了数据项(numpy 数组)和标签。

  2. DataLoader: 创建了一个 DataLoader 实例,设置 batch_size=10shuffle=True

  3. 批处理: 在遍历 DataLoader 时,每个批次的数据和标签会被自动转换为 PyTorch 张量。

自定义 collate 函数

如果需要自定义数据项的合并方式,可以提供自己的 collate 函数:

def my_collate_fn(batch):
    data, labels = zip(*batch)
    data = torch.tensor(data)
    labels = torch.tensor(labels)
    return data, labels

dataloader = DataLoader(dataset, batch_size=10, shuffle=True, collate_fn=my_collate_fn)

通过这种方式,可以灵活地控制数据加载和转换过程。如果 DataLoader 默认的行为无法满足需求,提供自定义 collate 函数是一个不错的解决方案。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值