PyTorch: torchvision datasets
1. Datasets (using the CIFAR10 dataset from torchvision.datasets as an example)
Basic operations:
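import torchvision
# Download CIFAR10 into ./dataset (the download is skipped if the files already exist)
train_set = torchvision.datasets.CIFAR10('./dataset', train=True, download=True)
test_set = torchvision.datasets.CIFAR10('./dataset', train=False, download=True)
print(test_set[0])  # each sample is an (image, target) tuple
print(test_set.classes)  # list of class names
img, target = test_set[0]
print(img)  # a PIL image, because no transform is set
print(target)  # integer class index
print(test_set.classes[target])  # class name for that index
img.show()  # open the image in the system image viewer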
Using transforms to transform every image in the dataset
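import torchvision
# Compose chains transforms; the only step here converts each PIL image to a tensor
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# The transform is applied to each image when it is read from the dataset
train_set = torchvision.datasets.CIFAR10('./dataset', train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=dataset_transform, download=True)
print(test_set[0])  # now a (tensor, target) tuple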
Writing images from test_set to TensorBoard
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10('./dataset', train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=dataset_transform, download=True)
writer = SummaryWriter('p10')  # event files are written to the ./p10 directory
for i in range(10):
    img, target = test_set[i]
    writer.add_image('test_set', img, i)  # tag, CHW image tensor, global step
writer.close()