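P15. Using DataLoader
Dataset: a deck of cards.
DataLoader: decides how many cards to draw at a time and how to draw them.
To find the docs, go to the PyTorch website -> Docs -> PyTorch, then type "dataloader" into the search box on the left.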
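To make the deck-of-cards analogy concrete, here is a minimal sketch of a map-style dataset: it only needs `__len__` (how many cards there are) and `__getitem__` (fetch one card by index). The class name `CardDeck` is hypothetical, and plain Python is used so the sketch runs without PyTorch; a real dataset would subclass `torch.utils.data.Dataset`. The "loader" part is a simplified stand-in for what DataLoader does, not its actual implementation.

```python
# Hypothetical "deck of cards" dataset following the map-style Dataset protocol:
# __len__ reports how many samples exist, __getitem__ returns one sample by index.
# A real PyTorch dataset would subclass torch.utils.data.Dataset; plain Python
# is used here so the sketch runs on its own.
class CardDeck:
    def __init__(self, num_cards):
        self.cards = ["card_{}".format(i) for i in range(num_cards)]

    def __len__(self):
        return len(self.cards)

    def __getitem__(self, idx):
        return self.cards[idx]


deck = CardDeck(10)
print(len(deck))  # 10
print(deck[0])    # card_0

# Simplified "loader" side: draw batch_size cards at a time, in index order.
batch_size = 4
hands = [
    [deck[i] for i in range(start, min(start + batch_size, len(deck)))]
    for start in range(0, len(deck), batch_size)
]
print([len(hand) for hand in hands])  # [4, 4, 2]
```

The dataset only answers "what is card i?"; deciding which indices to draw, and how many per batch, is left entirely to the loader.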
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
Parameters
dataset (Dataset) – dataset from which to load the data. Instantiate the custom Dataset written earlier and pass the instance to DataLoader.
batch_size (int, optional) – how many samples per batch to load (default: 1). How many cards to draw each time.
shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False). Shuffling the deck: with True, the cards come out in a different order on the second pass than on the first.
num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0) Multiprocessing: loading data with several subprocesses is usually faster; the default of 0 loads everything in the main process.
drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False) For example, with 100 cards drawn 3 at a time, 1 card is left over; True discards that last card, False returns it as a final batch of size 1.
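The interplay of `batch_size`, `shuffle` and `drop_last` can be sketched as a function that groups sample indices into batches. This is a simplified pure-Python model written for illustration (the name `make_batches` is hypothetical); PyTorch actually does this with its samplers, and this code only mimics the observable behavior described in the parameter list, including the 100-cards-in-3s example.

```python
import random


# Simplified model of how DataLoader groups sample indices into batches.
# NOT PyTorch's implementation: it only mimics batch_size, shuffle and drop_last.
def make_batches(num_samples, batch_size=1, shuffle=False, drop_last=False, rng=None):
    indices = list(range(num_samples))
    if shuffle:
        # A fresh permutation on every call, like reshuffling at every epoch.
        (rng or random).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the incomplete last batch
    return batches


# 100 cards drawn 3 at a time: 33 full batches plus 1 leftover card.
print(len(make_batches(100, 3)))                  # 34
print(make_batches(100, 3)[-1])                   # [99]
print(len(make_batches(100, 3, drop_last=True)))  # 33

# With shuffle=True, each "epoch" draws the same cards in a new order.
rng = random.Random(0)
epoch1 = make_batches(10, 3, shuffle=True, rng=rng)
epoch2 = make_batches(10, 3, shuffle=True, rng=rng)
print(epoch1)
print(epoch2)
```

Every index still appears exactly once per epoch when shuffling; only the order changes, which is why shuffling does not alter the number or sizes of the batches.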
import torchvision
from torch.utils.data import DataLoader

# prepare the test dataset (download=True fetches CIFAR-10 into ./dataset if it is not already there)
test_data = torchvision.datasets.CIFAR10(root="dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# draw 4 images per batch, reshuffle every epoch, keep the last incomplete batch
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# the first image and its target in the test dataset
img, target = test_data[0]
print(img.shape)
print(target)
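Output:
torch.Size([3, 32, 32])
3
(The image is a 3-channel 32x32 tensor; target 3 is the class index, which is "cat" in CIFAR-10.)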
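Each batch from the loader stacks `batch_size` samples along a new leading dimension, so with `batch_size=4` the image batch has shape `[4, 3, 32, 32]` and the targets hold 4 labels. The arithmetic can be sketched in plain Python; `batch_plan` is a hypothetical helper that assumes every sample has the same shape and that the CIFAR-10 test split has 10,000 images.

```python
# Hypothetical helper: predicts how many batches a DataLoader yields and the
# shape of the stacked image tensor in each batch. Assumes every sample has the
# same shape, as CIFAR-10 images do after ToTensor(): 3 x 32 x 32.
def batch_plan(num_samples, batch_size, sample_shape, drop_last=False):
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)  # smaller final batch when drop_last=False
    # The default collation stacks samples along a new leading dimension.
    return [(size,) + tuple(sample_shape) for size in sizes]


# CIFAR-10 test split: 10,000 images, loaded 4 at a time.
plan = batch_plan(10_000, 4, (3, 32, 32))
print(len(plan))  # 2500
print(plan[0])    # (4, 3, 32, 32)
```

Since 10,000 is divisible by 4, `drop_last` makes no difference here; it only matters when a remainder is left over.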
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# prepare the test dataset
test_data = torchvision.datasets.CIFAR10(root="dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 64 images per batch; drop_last=True discards the final incomplete batch
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

# the first image and its target in the test dataset
img, target = test_data[0]
print(img.shape)
print(target)

# log every batch to TensorBoard, writing event files to the "dataloader" directory
writer = SummaryWriter("dataloader")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        # one tag per epoch; because shuffle=True, the two epochs show different batches
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step = step + 1
writer.close()
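To see what the loop above writes to TensorBoard, the sequence of `(tag, step, batch size)` calls can be simulated without TensorBoard. This pure-Python sketch is illustrative only: `logging_schedule` is a hypothetical name, and it assumes the 10,000-image CIFAR-10 test split. With `batch_size=64`, 10,000 = 64 × 156 + 16, so `drop_last=True` gives 156 steps per epoch and discards 16 images; it also uses `enumerate()`, which is the more idiomatic way to get the step counter (the same works on a DataLoader: `for step, (imgs, targets) in enumerate(test_loader):`).

```python
# Pure-Python simulation (not TensorBoard itself) of the (tag, step, batch size)
# triples the two-epoch loop would pass to writer.add_images.
def logging_schedule(num_samples, batch_size, epochs, drop_last=True):
    full, rest = divmod(num_samples, batch_size)
    sizes = [batch_size] * full + ([rest] if rest and not drop_last else [])
    calls = []
    for epoch in range(epochs):
        # enumerate() yields the step counter directly, replacing step = step + 1
        for step, size in enumerate(sizes):
            calls.append(("Epoch: {}".format(epoch), step, size))
    return calls


calls = logging_schedule(10_000, 64, epochs=2)
print(len(calls))  # 312
print(calls[-1])   # ('Epoch: 1', 155, 64)

# Without drop_last, each epoch would end with a 157th batch of only 16 images.
print(logging_schedule(10_000, 64, epochs=1, drop_last=False)[-1])  # ('Epoch: 0', 156, 16)
```

Because every epoch restarts at step 0 under its own tag, TensorBoard shows two separate image series, and the step slider for each runs from 0 to 155.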