torchvision.datasets.CIFAR10()
是PyTorch中用于加载CIFAR-10数据集的函数,该数据集包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。其中50000张图像用于训练,10000张图像用于测试。下面是torchvision.datasets.CIFAR10()
函数中各个参数的详细介绍:
-
root(字符串类型):
- 作用:指定数据集下载后存放的本地路径。
- 示例:
root='./data'
,表示数据集将被下载并存储在当前目录下的data
文件夹中。
-
train(布尔类型):
- 作用:指定是否加载训练集。
- 取值:
True
表示加载训练集,False
表示加载测试集。 - 示例:
train=True
,表示加载CIFAR-10的训练集。
-
transform(可选,函数类型):
- 作用:指定对数据集中的图像进行变换的函数或操作。
- 常见变换:包括转换为张量(
ToTensor()
)、归一化(Normalize()
)、裁剪(RandomCrop()
)等。 - 示例:
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
,表示先将图像转换为张量,然后进行归一化处理。
-
target_transform(可选,函数类型):
- 作用:指定对数据集中的标签进行变换的函数。
- 示例:通常较少使用,因为CIFAR-10的标签已经是整数形式,可以直接用于分类任务。
-
download(布尔类型):
- 作用:指定是否从互联网上下载数据集。
- 取值:
True
表示如果数据集尚未下载,则从互联网下载;如果数据集已经下载,则不会重新下载。False
表示不下载数据集。 - 示例:
download=True
,表示如果数据集尚未下载,则进行下载。
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 定义数据变换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载训练集
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# 加载测试集
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)