# 04 - torchvision 中的数据集使用
import torchvision

# Both CIFAR10 splits live under the same root directory. With download=True,
# torchvision fetches the archive there if it is not already present.
_cifar10_kwargs = dict(root="./dataset", download=True)
train_set = torchvision.datasets.CIFAR10(train=True, **_cifar10_kwargs)   # training split
test_set = torchvision.datasets.CIFAR10(train=False, **_cifar10_kwargs)   # test split
注意：设置 download=True 后，若 root 目录下还没有数据集，会自动下载（已存在则直接使用）。接下来将数据集与 transforms 和 TensorBoard 联动使用。
import torchvision
from torch.utils.tensorboard import SummaryWriter

# Chain transforms with Compose. ToTensor turns each PIL image into a float
# tensor in [0, 1] with CHW layout, which is add_image's default dataformats.
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# Load both CIFAR10 splits with the transform applied to every sample.
train_set = torchvision.datasets.CIFAR10(
    root="./dataset", train=True, download=True, transform=dataset_transform
)
test_set = torchvision.datasets.CIFAR10(
    root="./dataset", train=False, download=True, transform=dataset_transform
)

# After ToTensor a sample is (tensor, label), so the PIL-only img.show() no
# longer works. To inspect a sample instead:
# img, target = test_set[0]
# print(img.shape, target)

# Log the first 10 *test* images under the "test_set_image" tag (the data now
# matches the tag). The loop index is used as the step, so TensorBoard's
# slider walks through them. The with-block closes the writer even on error.
with SummaryWriter("p10") as writer:
    for i in range(10):
        img, _ = test_set[i]
        writer.add_image("test_set_image", img, i)