一、使用torchvision给我们提供的标准数据集
import torchvision
train_set = torchvision.datasets.CIFAR10(root=r"./dataset", train=True, download=True)#在文件里下载一个命名为dataset的文件
test_set = torchvision.datasets.CIFAR10(root=r"./dataset", train=False, download=True)#测试数据集
#看dataset中的第一个数据集
print(test_set[0])
print(test_set.classes)
img,target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img_show()
print(test_set[0])的结果
解析:(输入图片,target)
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=Flase, download=True)
二、与transforms结合使用
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root=r"./dataset", train=True, transform=dataset_transform, download=True)#在文件里下载一个命名为dataset的文件
test_set = torchvision.datasets.CIFAR10(root=r"./dataset", train=False, transform=dataset_transform, download=True)#测试数据集
#看dataset中的第一个数据集
#print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
img,target = test_set[i]
writer.add_image("test_set",img,i)
writer.close()
#terminal-输入tensorboard --logdir="p10"
#打开链接
小提示:
pytorch官网-docs-torchvision