import torchvision
from torch.utils.tensorboard import SummaryWriter
1.数据集下载的是PIL格式,把他转换成tensor
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
2.数据集下载
root代表数据集存放的目录
train如果是True为训练集,为False是测试集
transform 进行格式转换
download 为True进行数据集的下载,如果存在则不下载
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
3.img, target = train_set[0]的说明
test_data[i]调用的 是getitem() return img,target
train_set[0]的结果是 { 图片信息,位置}
所以img得到图片信息,target得到位置
4.进行图片显示,采用迭代器的循环
writer = SummaryWriter('r10')
for i in range(10):
img, target = train_set[i]
writer.add_image("P10", img, i)
writer.close()