import torchvision
from tensorboardX import SummaryWriter
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# compose是把一些变化操作组合在一起,这里除了转化成tensor以外还可以添加裁剪等处理
# CIFAR数据集中的数据是PIL,需要转为tensor才能输入到writer
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
# 如果数据集已存在设置成download=True也不会再下载一遍 所以一直True就完事了
#如果下载的很慢可以复制网址到迅雷下载,下完直接把tar.gz包复制到dataset文件夹下面
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
#train=True是训练集,false就是测试集
print(test_set[0])
#可以去官网查看test_set的输出类型,也可以print(test_set)查看,这里test_set的输出由img和label的索引组成 test_set[0]是查看第一个图片的tensor信息以及类别索引
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
#终端输入:tensorboard --logdir=p10 --port=9999
writer.close()
【PyTorch笔记】pytorch入门教程5pytorch的数据库
最新推荐文章于 2024-09-21 22:47:11 发布