目录
步骤1:数据准备
首先,我们需要准备一个包含各种类别图像的数据集。对于本示例,我们将使用一个公开可用的图像分类数据集,如ImageNet、CIFAR-10或自定义数据集。在这里,我们将使用CIFAR-10数据集作为示例,它包含10个不同的类别,包括飞机、汽车、鸟类、猫、狗、青蛙、马、船、卡车和自行车。
你可以在PyTorch中使用torchvision
库轻松加载CIFAR-10数据集:
import torchvision
import torchvision.transforms as transforms
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, dow