7.1 数据集下载及操作
Pytorch官网 link Docs -> torchvision 提供有各种数据集的API文档。
以 CIFAR10 为例。
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True) # "."参照当前文件所在目录
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
运行后
# 下载慢时可复制地址至迅雷进行下载
进行断点测试,可知各个真实类别对应数字
进一步验证
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True) # "."参照当前文件所在目录
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print(test_set[0]) # 打印 test 数据集的第一张图片的所有信息
print(test_set.classes) # 打印数据集的所有类别信息
img, target = test_set[0]
print(img) # 图片信息
print(target) # 类别信息
print(test_set.classes[target])
img.show()
运行结果
7.2 结合Transforms
将所有图片转化成 tensor 类型并验证
import torchvision
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# print(test_set[0]) # 打印 test 数据集的第一张图片的所有信息
# print(test_set.classes) # 打印数据集的所有类别信息
#
# img, target = test_set[0]
# print(img) # 图片信息
# print(target) # 类别信息
# print(test_set.classes[target])
# img.show()
print(test_set[0])
使用 TensorBoard 显示
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# print(test_set[0]) # 打印 test 数据集的第一张图片的所有信息
# print(test_set.classes) # 打印数据集的所有类别信息
#
# img, target = test_set[0]
# print(img) # 图片信息
# print(target) # 类别信息
# print(test_set.classes[target])
# img.show()
# print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
![](https://img-blog.csdnimg.cn/1bcc08e33c6c4fabbc133d9a5723728c.png#pic_center)