原文链接：PyTorch DataLoader: A Complete Guide • datagy
1、理解 DataLoader 类
# Understanding the PyTorch DataLoader Class
import inspect

from torch.utils.data import DataLoader

# The constructor signature is shown as a comment because it is a signature,
# not a runnable call: `dataset` is a placeholder, and a bare `*,` is not valid
# inside a call expression. Every parameter after `*` is keyword-only.
#
# DataLoader(
#     dataset,
#     batch_size=1,
#     shuffle=False,
#     sampler=None,
#     batch_sampler=None,
#     num_workers=0,
#     collate_fn=None,
#     pin_memory=False,
#     drop_last=False,
#     timeout=0,
#     worker_init_fn=None,
#     multiprocessing_context=None,
#     generator=None,
#     *,
#     prefetch_factor=2,
#     persistent_workers=False,
# )
#
# Parameters and defaults differ between PyTorch releases (for example, newer
# releases default `prefetch_factor` to None and add `pin_memory_device`), so
# print the signature of the installed version to be sure:
print(inspect.signature(DataLoader))
2、创建和使用 PyTorch DataLoader 类
# Loading the MNIST Dataset Using PyTorch
# Importing Libraries
from torchvision import transforms  # required for transforms.ToTensor() below
from torchvision.datasets import MNIST

# Downloading and Saving MNIST.
# ToTensor converts each image into a float tensor scaled to [0.0, 1.0].
data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())

# Accessing a Dataset Item: each item is an (image_tensor, label) tuple.
print(data_train[0])
# Returns:
# (tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# ...
# 0.0000, 0.0000, 0.0000, 0.0000]]]), 5)
可视化一个样本：
# Visualizing a Sample
import matplotlib.pyplot as plt

# `data_train.data` is the dataset's stored image tensor; the ToTensor transform
# is only applied when indexing `data_train[i]`. MNIST digits are single-channel
# grayscale, so request a gray colormap explicitly; otherwise matplotlib renders
# the digit in its default false-colour map.
plt.imshow(data_train.data[0], cmap='gray')
plt.show()
输出如下（原文此处为图片：训练集的第一个样本，即标签为 5 的手写数字）：
加载数据集，然后创建自己的 DataLoader：
# Creating a Training DataLoader Object
from torch.utils.data import DataLoader
from torchvision import transforms  # required for transforms.ToTensor() below
from torchvision.datasets import MNIST

# Downloading and Saving MNIST
data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())

# Creating Data Loader: 20 samples per batch, with the data reshuffled at
# every epoch.
data_loader = DataLoader(data_train, batch_size=20, shuffle=True)
print(data_loader)
# Returns (the memory address will differ on every run):
# <torch.utils.data.dataloader.DataLoader object at 0x7fc3c021b6d0>
3、遍历 PyTorch DataLoader
# Loading the First Batch and Printing Information
for batch_idx, batch in enumerate(data_loader):
    # Each batch is a pair: the stacked image tensor and the label tensor.
    images, labels = batch[0], batch[1]
    print('Batch index: ', batch_idx)
    # .size() is the full shape (batch, channel, height, width), not just the count.
    print('Batch size: ', images.size())
    print('Batch label: ', labels)
    break  # inspect only the first batch
# Returns:
# Batch index: 0
# Batch size: torch.Size([20, 1, 28, 28])
# Batch label: tensor([3, 3, 7, 7, 2, 4, 7, 2, 1, 8, 3, 3, 9, 3, 2, 3, 5, 0, 6, 8])
4、从 PyTorch DataLoader 中获取数据和标签（目标）
# Accessing Data and Targets in a PyTorch DataLoader
# Unpack every batch directly into its images and their labels.
for images, labels in data_loader:
    print(images[0])  # first image tensor in the batch
    print(labels[0])  # the label belonging to that image
    break  # the first batch is enough for this demonstration
# Returns:
# tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# ...
# 0.0000, 0.0000, 0.0000, 0.0000]]])
# tensor(1)
5、使用 PyTorch DataLoader 将数据加载到 GPU（CUDA）
# Loading Data to a GPU with a PyTorch DataLoader Object
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Choose the device first so the loader can be configured for it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_cuda = device.type == 'cuda'

data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())
# Pinned (page-locked) host memory lets host-to-GPU copies run asynchronously.
# It only helps with a CUDA device, so enable it only in that case.
data_loader = DataLoader(data_train, batch_size=20, shuffle=True, pin_memory=use_cuda)

for data, target in data_loader:
    # non_blocking=True lets the copy from pinned memory overlap with host work.
    # When the tensor is already on `device`, .to() simply returns it unchanged.
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    # ... the model's forward/backward pass on `data` and `target` goes here ...