DataLoader has many parameters, but the following five are used most often:
- dataset: the Dataset instance to read from; just pass in the dataset you built;
- batch_size: how many samples to load per batch;
- num_workers: the number of subprocesses used for data loading (0 means the data is loaded in the main process);
- shuffle: whether to reshuffle the data at every epoch;
- drop_last: whether to drop the last batch when the number of samples is not divisible by batch_size.
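To make the batch_size/drop_last interaction concrete, here is a minimal pure-Python sketch (not PyTorch's actual implementation) using a list of integers to stand in for samples:

```python
# Minimal sketch of batching: slice the dataset into chunks of batch_size,
# then optionally discard an incomplete final chunk when drop_last is True.
def make_batches(dataset, batch_size, drop_last):
    batches = [dataset[i:i + batch_size]
               for i in range(0, len(dataset), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # the last batch is smaller than batch_size: drop it
    return batches

samples = list(range(10))
print(make_batches(samples, 4, drop_last=False))  # [[0,1,2,3], [4,5,6,7], [8,9]]
print(make_batches(samples, 4, drop_last=True))   # [[0,1,2,3], [4,5,6,7]]
```

With 10 samples and batch_size=4, drop_last=False yields a final short batch of 2 samples, while drop_last=True discards it.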
Example:
First the imports:
import torchvision
from torch.utils.data import DataLoader
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train = torchvision.datasets.CIFAR10(root='./dataset', train=True,
transform=dataset_transform, download=False)
test = torchvision.datasets.CIFAR10(root='./dataset', train=False,
transform=dataset_transform, download=False)
test_loader = DataLoader(dataset=test, batch_size=64, shuffle=True,
num_workers=0, drop_last=False)
Let's look at the DataLoader's output:
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
The first two batches are shown above. imgs is a tensor of shape [64, 3, 32, 32], that is, a batch of 64 images with 3 channels of 32×32 pixels, and targets is a tensor of the 64 corresponding labels.
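The batch count is easy to work out by hand: the CIFAR-10 test set has 10,000 images, so with batch_size=64 a quick check shows how many full batches there are and how large the leftover batch is:

```python
import math

num_samples = 10_000  # size of the CIFAR-10 test set
batch_size = 64

full_batches, remainder = divmod(num_samples, batch_size)
print(full_batches, remainder)               # 156 full batches, 16 images left over
print(math.ceil(num_samples / batch_size))   # 157 batches when drop_last=False
```

So with drop_last=False the loader yields 157 batches (the last holding only 16 images), and with drop_last=True it yields 156.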
Write the batches to TensorBoard to inspect them visually:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("logs")
step = 0
for epoch in range(2):
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step = step + 1
writer.close()
Since test_loader was created with shuffle=True, the two epochs display the images in different orders; if shuffle is instead set to False, both epochs look identical.
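What shuffle controls can be sketched in plain Python (again, a simplified model, not PyTorch's internals): each epoch the loader visits either the fixed index order or a fresh permutation of it.

```python
import random

# Sketch: with shuffle=False the visiting order is the same every epoch;
# with shuffle=True each epoch draws a new permutation of the indices.
def epoch_order(indices, shuffle, rng):
    order = indices[:]          # copy so the original index list is untouched
    if shuffle:
        rng.shuffle(order)      # fresh permutation for this epoch
    return order

indices = list(range(8))
rng = random.Random(0)
print(epoch_order(indices, shuffle=False, rng=rng))  # identical every epoch
print(epoch_order(indices, shuffle=False, rng=rng))
print(epoch_order(indices, shuffle=True, rng=rng))   # generally differs per epoch
```

This is why, with shuffle=False, the two epochs in TensorBoard show the same grid of images at each step.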
Left image: drop_last=False; right image: drop_last=True.