`torch.utils.data.DataLoader`是PyTorch中用于加载数据的工具,它可以将数据集按照batch size打包成一个一个的小批量数据,方便模型进行训练。下面是`torch.utils.data.DataLoader`的介绍:
```python
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
```
参数说明:
- `dataset`:数据集,必须是`torch.utils.data.Dataset`的子类。
- `batch_size`:每个batch中的样本数,默认为1。
- `shuffle`:是否打乱数据集,默认为False。
- `sampler`:自定义的采样器,如果指定了`sampler`,则`shuffle`必须为False。
- `batch_sampler`:自定义的batch采样器。
- `num_workers`:用于数据加载的子进程数,默认为0,表示在主进程中加载数据。
- `collate_fn`:用于将样本列表转换为mini-batch张量的函数,默认使用`default_collate`函数。
- `pin_memory`:是否将数据加载到CUDA固定内存中,默认为False。
- `drop_last`:如果数据集大小不能被batch size整除,则是否丢弃最后一个不完整的batch,默认为False。
- `timeout`:数据加载器等待数据的超时时间,默认为0,表示不等待。
- `worker_init_fn`:每个worker初始化函数。
- `prefetch_factor`:预取因子,用于预取数据,默认为2。
- `persistent_workers`:是否使用持久化的worker进程,默认为False。
下面是一个使用`torch.utils.data.DataLoader`的例子:
```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 加载MNIST数据集
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
# 定义数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
# 迭代数据加载器
for batch_idx, (data, target) in enumerate(train_loader):
# 训练模型
pass
```