加载数据并打印(print)查看
CIFAR10是torchvision内置支持的数据集:torchvision.datasets.CIFAR10 提供了加载接口,并可自动从网上下载数据。
import torchvision
# Load CIFAR10 through torchvision.
#   root:     directory where the dataset is stored.
#   train:    defaults to True; True selects the training split,
#             False selects the test split.
#   download: fetch the dataset from the internet automatically.
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

# First sample of the test split: an (image, target) pair.
print(test_set[0])
# Class names; a target is an index into this list.
print(test_set.classes)

img, target = test_set[0]
print(img)
# Per the original notes, this first test image is a cat.
print(target)
print(test_set.classes[target])
# No transform is set here, so img is a PIL image; show() opens it in a viewer.
img.show()
与transforms结合使用
数据集默认返回的图片是PIL.Image类型,需要用transforms.ToTensor转换成tensor类型(例如才能用TensorBoard的add_image显示)。
import torchvision
from torch.utils.tensorboard import SummaryWriter
# Transform pipeline: convert each PIL image returned by the dataset into a tensor.
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# transform=dataset_transform applies ToTensor to every image the dataset returns.
train_set = torchvision.datasets.CIFAR10(
    root="./dataset", train=True, transform=dataset_transform, download=True
)
test_set = torchvision.datasets.CIFAR10(
    root="./dataset", train=False, transform=dataset_transform, download=True
)

# Log the first 10 test images to TensorBoard under the "test_set" tag,
# using the sample index as the step; logs are written under "p10".
writer = SummaryWriter("p10")
try:
    for i in range(10):
        img, target = test_set[i]
        writer.add_image("test_set", img, i)
finally:
    # Close the writer even if loading or logging fails part-way,
    # so buffered events are flushed and the file handle is released.
    writer.close()