DataLoader的使用
torch.utils.data.DataLoader
形象理解:
-
dataset:一副扑克
-
dataloader:抽牌方式
CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, ***, prefetch_factor=None, persistent_workers=False, pin_memory_device='')[SOURCE]
常用参数的通俗解释
-
dataset: 自定义数据集
-
batchsize: 每次抽牌抽几张,默认为 1
-
shuffle: 每局牌局前是否洗牌(牌堆的顺序是否一样),一般设置为 True
num_workers: 加载数据时采用的进程数量,默认为 0
但是在windows操作系统下设置为大于0的值时可能会出现问题:
"BrokenPipeError"
-
drop_last: 牌堆里有100张牌,每次取7张的话则到最后一定会剩余2张,设置为True则为舍弃之
测试代码
import torchvision from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter # 准备的测试数据集 test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=False , transform=torchvision.transforms.ToTensor()) test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False) # 按住ctrl点击CIFAR10查看源文件中的getitem()方法,发现返回值类型为:img, target # 测试数据集中第一个样本图片及其对应的target img, target = test_data[0] print(img.shape) # torch.Size([3, 32, 32]) print(target) # 3 # 测试加载之后的数据集 # 对比之前的输出值即可理解之 for data in test_loader: imgs, targets = data print(imgs.shape) # torch.Size([4, 3, 32, 32]) # 注意此处输出的第一个样本target值为 5 ,不同于前面的 3 # 这是因为dataloader中的 sampler 为torch.utils.data.sampler.RandomSampler # 说明每次从牌堆中取出的 4 张牌是随机取的! print(targets) # tensor([5, 4, 3, 7]) test_loader2 = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False) writer = SummaryWriter("logs") step = 0 for epoch in range(2): # shuffle为False,则 2 轮牌堆中的牌顺序是相同的 for imgs, targets in test_loader2: writer.add_images(f"Epoch:{epoch}", imgs, step) step += 1 writer.close()