课程链接: https://www.bilibili.com/video/BV1hE411t7RN?p=14&vd_source=a16915472897bc5c811d5ff185570c98
课堂笔记
使用CIFAR10数据集,以单个图片为例
先下载数据集
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print(test_set[0])
输出:
(<PIL.Image.Image image mode=RGB size=32x32 at 0x1B339186100>, 3)
在print(test_set[0])处打一个断点,点击debug
可以看到test_set下面有个classes
把属性classes打印出来:
print(test_set.classes) #打印完整的classes列表
print(test_set.classes[1]) #打印第2个class
输出结果:
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
automobile
把test_set[0]的img和target赋给变量img和target
img, target = test_set[0]
print(img)
print(target)
输出的是:
<PIL.Image.Image image mode=RGB size=32x32 at 0x1E7139E7160>
3
在classes列表中找第3+1个,可知这张照片是cat
打印一下看看
print(test_set.classes[target])
输出:
cat
使用CIFAR10数据集,与transforms联动,以ToTensor为例,多图
import torchvision
# 与transforms联动
# 批量处理训练集和测试集,把原本的PIL image转换成Tensor image
# 以适应后续深度学习中的函数对图片类型的要求
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
]) # 创建一个名为dataset_transform的工具,功能是把PIL image转换成Tensor image
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# 上两行中的transform=dataset_transform是使用工具dataset_transform
# 把训练集和测试集中的PIL image转换成Tensor image
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("p14")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
Terminal中运行tensorboard --logdir="p14"打开tensorboard
可以看到一共10步,拖动滑块一次展示10个图片