1. torchvision数据集介绍
- torchvision中有很多数据集,当我们写代码时指定相应的数据集指定一些参数,它就可以自行下载。
- CIFAR-10数据集包含60000张32×32的彩色图片,一共10个类别,其中50000张训练图片,10000张测试图片。
2. 查看CIFAR10数据集内容
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True) # root为存放数据集的相对路线
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True) # train=True是训练集,train=False是测试集
print(test_set[0]) # 输出的3是target
print(test_set.classes) # 测试数据集中有多少种
img, target = test_set[0] # 分别获得图片、target
print(img)
print(target)
print(test_set.classes[target]) # 3号target对应的种类
img.show()
备注:如果下载过慢可以复制python控制台中的链接到迅雷下载,将文件复制粘贴到root目录下即可
3. Tensorboard查看内容
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True) # 将ToTensor应用到数据集中的每一张图片,每一张图片转为Tensor数据类型
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
writer = SummaryWriter("logs")
for i in range(11):
img, target = test_set[i]
writer.add_image("test_set",img,i)
print(img.size())
writer.close() # 一定要把读写关闭,否则显示不出来图片
随后在Anaconda终端激活虚拟环境,输入tensorboard --logdir="C:\Users\Asabopp\Desktop\learn_pytorch\logs" 命令,将网址赋值浏览器的网址栏,回车,即可查看tensorboard显示日志情况。