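1 Parameters
Most parameters have default values, so in practice you only need to set a few of them.
Using a deck of playing cards as an analogy: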
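- batch_size: how many cards you draw at a time, i.e. the number of samples per batch.
- shuffle: whether to shuffle the deck. If True, the cards come out in a different order each time you play through the deck. The default is False; True is the usual choice.
- num_workers: whether the data is loaded by a single process or by several. Multiple processes are usually faster. The default is 0, which loads in the main process. On Windows, a value greater than 0 can fail with BrokenPipeError, commonly when the script's entry point is not wrapped in `if __name__ == "__main__":`.
- drop_last: with 100 cards drawn 3 at a time, 1 card is left over at the end. This parameter decides whether that leftover card is discarded. True discards it; False still returns it as a final, smaller batch.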
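The card analogy above can be checked directly. The following is a minimal sketch, assuming a plain `range(100)` is an acceptable stand-in for a dataset (the name `deck` is made up for illustration); it draws 3 cards at a time with and without `drop_last`.

```python
from torch.utils.data import DataLoader

# A toy "deck" of 100 cards numbered 0..99 (hypothetical stand-in for a real dataset).
deck = range(100)

# Draw 3 cards at a time and keep the incomplete final draw.
keep_loader = DataLoader(deck, batch_size=3, shuffle=False, drop_last=False)
# Same draw size, but discard the incomplete final draw.
drop_loader = DataLoader(deck, batch_size=3, shuffle=False, drop_last=True)

keep_batches = list(keep_loader)
drop_batches = list(drop_loader)

# Number of draws and size of the last draw for each setting.
print(len(keep_batches), len(keep_batches[-1]))  # 34 1
print(len(drop_batches), len(drop_batches[-1]))  # 33 3
```

With `drop_last=False` the 100th card comes out alone in a 34th draw; with `drop_last=True` it is never returned.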
2 Usage
import torchvision
from torch.utils.data import DataLoader

# Convert each PIL image to a tensor
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

test_data = torchvision.datasets.CIFAR10(root="./datasets", train=False, transform=dataset_transform, download=True)
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# First image in the test dataset and its target
img, target = test_data[0]
print(img.shape)
print(target)

# Each iteration yields one batch of 4 images and their 4 targets
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
Displaying the batches in TensorBoard
1. batch_size is 64, drop_last is False
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

test_data = torchvision.datasets.CIFAR10(root="./datasets", train=False, transform=dataset_transform, download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

writer = SummaryWriter("log2")
step = 0
for data in test_loader:
    imgs, targets = data
    # Log the whole batch as one image grid at this step
    writer.add_images("test_data", imgs, step)
    step += 1
writer.close()
Because drop_last is set to False, the final step still shows the last 16 images, even though they do not fill a complete batch of 64.
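The figure of 16 follows from simple arithmetic. A short sketch, assuming the CIFAR-10 test split holds 10,000 images (the variable names are illustrative only):

```python
import math

num_images = 10_000  # size of the CIFAR-10 test split
batch_size = 64

# How many full batches fit, and how many images are left over.
full_batches, leftover = divmod(num_images, batch_size)

steps_keep = math.ceil(num_images / batch_size)  # number of steps with drop_last=False
steps_drop = num_images // batch_size            # number of steps with drop_last=True

print(full_batches, leftover, steps_keep, steps_drop)  # 156 16 157 156
```

So the loop above logs 157 steps, and setting drop_last to True would log 156 steps and skip the 16 leftover images.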
2. shuffle is True
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

test_data = torchvision.datasets.CIFAR10(root="./datasets", train=False, transform=dataset_transform, download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

writer = SummaryWriter("log2")
for epoch in range(2):
    # Restart the step counter so both epochs can be compared step by step
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch:{}".format(epoch), imgs, step)
        step += 1
writer.close()
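One complete `for data in test_loader` loop is one full round of cards, that is, one pass over all the data. With shuffle set to True (the usual setting), the deck is reshuffled before the next round, so the images shown at the same step differ between the two epochs.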
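The reshuffling can also be seen without TensorBoard. This is a small sketch, assuming a toy `range(100)` dataset and a seeded `torch.Generator` passed through the `generator` argument; the seed and the names `deck` and `epoch_orders` are assumptions for illustration, not part of the original example.

```python
import torch
from torch.utils.data import DataLoader

# Toy deck of 100 cards (hypothetical stand-in for the image dataset).
deck = range(100)

# A seeded generator makes the shuffling repeatable across runs (assumption: seed 0).
g = torch.Generator().manual_seed(0)
loader = DataLoader(deck, batch_size=100, shuffle=True, generator=g)

epoch_orders = []
for epoch in range(2):
    # batch_size equals the deck size, so each epoch yields exactly one batch.
    for batch in loader:
        epoch_orders.append(batch.tolist())

# Both epochs contain the same cards, drawn in a different order.
print(epoch_orders[0][:5])
print(epoch_orders[1][:5])
```

Each pass over `loader` draws a fresh permutation, which is why the two TensorBoard tags show different images at the same step.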