#%%
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import sys
# Number of images per batch drawn from the DataLoader.
batch_size = 2

# Convert each image to a tensor, then standardize it.
# 0.1307 / 0.3081 are presumably the MNIST pixel mean / std (the values used
# in the official PyTorch MNIST example) -- TODO confirm if the dataset changes.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# train=True selects the training split, train=False the test split.
# NOTE: despite the "_train" suffix on the names below, this script currently
# loads the *test* split; flip this flag to inspect the training split.
use_train_split = False

# download=True only fetches the files when they are missing under `root`;
# with download=False a fresh checkout fails with RuntimeError instead.
MNIST_dataset_train = datasets.MNIST(
    root='./data/mnist',
    train=use_train_split,
    download=True,
    transform=transform,
)
# Shuffle so each run shows a different sample.
dataloaders_train = DataLoader(dataset=MNIST_dataset_train, batch_size=batch_size, shuffle=True)
#%% Inspect one sample from the loader.
# The training split has 60,000 images and the test split 10,000, one label
# per image. Each batch is a pair (x, y): x is an N x 1 x 28 x 28 image tensor
# (N = batch_size) and y is a 1-D tensor holding the N labels.
#
# Take only the first batch. The previous version iterated the loader and
# called sys.exit() on the second batch, which terminates the interpreter
# (or raises SystemExit in an interactive kernel) and skips any later code.
x, y = next(iter(dataloaders_train))

# Show the first image of the batch together with its label.
print('label:', y[0])
# MNIST digits are single-channel, so render them in grayscale.
plt.imshow(x[0, 0, :, :], cmap='gray')
# plt.show() keeps the window open when run as a script; the old
# plt.pause(0.001) followed by an exit closed it almost immediately.
plt.show()
# Example output from one run (the label varies because shuffle=True):
# label: tensor(3)
# ![](https://img-blog.csdnimg.cn/20210126130406336.png)