import torchvision
from torch.utils.data import DataLoader
#准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./data",train=False,transform=torchvision.transforms.ToTensor())
test_loader = DataLoader (dataset=test_data,
batch_size=4,
shuffle=True,
num_workers=0,
drop_last=False)
#测试数据集中第一张图片及label
img,label= test_data[0]
print(img.shape)
print(label)
结果: 3通道,32*32,label=3
for data in test_loader:
imgs,labels = data
print(imgs.shape)
print(labels)
部分结果:
4张图片,3通道,32*32
4张图片的标签进行打包