DataLoader类

`DataLoader` 类是 PyTorch 中用于构建数据加载器的一个重要工具,它可以对数据集进行批处理、洗牌和并行加载,以便于训练神经网络模型。

### 输入参数:
- **dataset**:数据集对象,通常是 `torch.utils.data.Dataset` 类的子类对象,用于包装需要加载的数据。
- **batch_size**:每个批次中包含的样本数量。
- **shuffle**:一个布尔值,表示是否在每个 epoch 前洗牌数据。
- **num_workers**:用于数据加载的子进程数量。
- **collate_fn**:用于自定义批处理方式的函数,通常在需要对每个批次进行一些自定义处理时使用。
- **drop_last**:一个布尔值,表示是否丢弃最后一个不完整的批次,当数据总数不能被 batch_size 整除时使用。

### 输出:
`DataLoader` 对象,可以通过迭代器的方式逐批次地加载数据,每个批次的数据以字典或元组的形式返回。下面是一个简单的示例:
```python

import torch
from torch.utils.data import Dataset, DataLoader

# 定义自定义数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# 创建数据集对象
data = [i for i in range(100)]
dataset = MyDataset(data)

# 创建 DataLoader 对象
batch_size = 10
shuffle = True
num_workers = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

# 迭代加载数据
for step, batch in enumerate(train_dataloader):


```

在这个示例中,`batch` 是一个由数据组成的张量,它的大小为 `[batch_size]`。根据需要,你可以对 `collate_fn` 进行自定义来改变输出的形式。

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值