完整代码
data_dir = '../data/cifar/'
apply_transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
transform=apply_transform)
test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
transform=apply_transform)
datasets.CIFAR10()是什么意思?
要理解这短短几行代码,可以从datasets.CIFAR10(data_dir, train=True, download=True, transform=apply_transform)入手。
datasets是包含了很多个数据集的模块,其中有CAFAR10数据集。
CIFAR10是datasets中的一个类,构造函数的参数如下:
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
-
root
(str): CIFAR-10 数据集的存储位置。如果数据集还未下载,且download=True
,它将被下载到这个目录。如果数据集已经在这个目录中,就会直接从这个位置加载。 -
train
(bool): 决定是否加载训练集。如果设置为True
,将加载训练集;如果设置为False
,将加载测试集。 -
transform
(Optional[Callable]): 可选参数,允许用户传递一个预处理转换的函数或转换的组合。这些转换将在加载数据集时应用于每个图像。通过这个参数,你可以进行诸如图像裁剪、缩放、归一化等操作。 -
target_transform
(Optional[Callable]): 这个可选参数允许用户传递一个转换函数,用于处理每个图像的目标标签(例如分类标签)。通过这个参数,你可以执行一些特定的目标转换或编码。 -
download
(bool): 决定如果数据集不在指定的root
位置,是否应该下载它。如果download=True
,数据集将被下载到指定目录。如果已经存在,则不会重新下载。
datasets.CIFAR10
是 PyTorch 中用于加载 CIFAR-10 数据集的函数。CIFAR-10 是一个常用的图像分类数据集,包括10个类别和60000个32x32彩色图像。以下是该函数中各个参数的解释:
-
data_dir
(数据目录):该参数指定了CIFAR-10数据集的存储位置。如果数据集还未下载,它将被下载到这个目录。 -
train
(训练):这个布尔参数决定了是否加载训练集。如果train=True
,则加载训练集;如果train=False
,则加载测试集。CIFAR-10数据集分为50000个训练图像和10000个测试图像。 -
download
(下载):这个布尔参数决定了如果数据集不在data_dir
指定的位置,是否应该下载它。如果download=True
,数据集将被下载到指定目录。如果已经存在,则不会重新下载。 -
transform
(变换):该参数是可选的,允许用户传递一个预处理转换的函数或转换的组合。在加载数据集的同时,这些转换将被应用于每个图像。例如,你可以使用transforms.Compose
来将多个转换组合在一起,例如将图像转换为张量,然后对其进行标准化。
这个函数返回一个 torch.utils.data.Dataset
对象,可以用于训练和测试模型。它方便地整合了数据的加载、转换和批处理等功能,使得与PyTorch模型的整合更为顺畅。
体现Python灵活性的transforms.Compose()
transforms.Compose()
是 PyTorch 中 torchvision.transforms
模块的一个功能,用于组合多个图像变换操作。你可以将一系列的图像变换操作(每个操作都是可调用的对象)传递给 transforms.Compose()
,然后它会按照它们在列表中的顺序来依次应用这些变换。
这非常有用,因为在预处理图像(例如,用于深度学习模型的训练)时,通常需要按照特定的顺序执行多个步骤。transforms.Compose()
允许你方便地捆绑这些步骤。
以下是一个更完整的示例,展示了如何使用 transforms.Compose()
来组合几个常见的图像变换操作:
from torchvision import transforms
# 定义一组要组合的图像变换
transform = transforms.Compose([
transforms.Resize(256), # 将图像缩放到 256x256
transforms.RandomCrop(224), # 在图像中随机裁剪 224x224 的区域
transforms.ToTensor(), # 将 PIL 图像或 NumPy ndarray 转换为张量
transforms.Normalize(mean=[0.485, 0.456, 0.406], # 归一化图像张量
std=[0.229, 0.224, 0.225])
])
# 然后,你可以将此组合的转换传递给数据加载器或直接应用于图像
通过使用 transforms.Compose()
,你可以确保预处理步骤在整个数据集上保持一致,并且代码更加简洁易读。
强大的transforms.ToTensor()
使用transforms.ToTensor()
后,每个PIL图像将转换为32x32x3的PyTorch张量,并且像素值的范围将从0-255映射到0-1的范围。这样的转换非常有用,因为许多神经网络期望输入在0-1范围内,并且预期的输入维度为[通道,高度,宽度]。这个转换将CIFAR-10数据整齐地格式化,使其可以直接用于大多数现代卷积神经网络。
排列顺序从HWC变成CHW有什么好处?
-
一致性:许多深度学习库和模型期望数据按照CHW格式组织。保持这种一致性有助于减少混乱,使得预处理和模型训练更流畅。
-
批量处理效率:当批量传递多个图像时,通道维度被组织在一起可以增加计算效率。例如,如果有一个批量大小为64的图像,每个图像的大小为3x32x32,则批量数据的大小将是64x3x32x32。这种组织方式有助于并行计算,特别是在卷积层中。
-
与预训练模型的兼容性:许多预训练的模型,如ResNet和VGG,都期望输入图像为CHW格式。使用CHW格式可以更容易地将图像传递给这些模型,而无需额外转换。
为什么要把PIL转换成张量?
PIL(Python Imaging Library)是一个用于打开、操作和保存许多不同图像文件格式的Python库。PIL现在是一个废弃的项目,但其后继者,Pillow库,继承了PIL并进行了改进和维护。PIL图像是PIL库处理的对象类型,用于代表图像。这个对象提供了许多方法,用于对图像进行操作,例如裁剪、旋转、调整大小、滤镜等。图像可以以各种模式存在,如RGB、L(灰度)等。
当你使用像datasets.CIFAR10
这样的Torchvision数据集时,除非指定了转换,否则图像将作为PIL图像对象返回。但这个是整个图像,并不是能直接处理的数据(张量)。例如:
from PIL import Image
# 创建一个PIL图像对象
image = Image.open("example.jpg")
# 使用PIL的方法进行图像操作
image = image.rotate(45) # 旋转45度
image.show() # 显示图像
不可缺少的transforms.Normalize()
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
这表示对每个通道(例如RGB图像的R、G和B通道)减去0.5的均值,并除以0.5的标准差。这样可以确保每个通道的值都具有相似的尺度。
transforms.Normalize
方法的工作原理是接收两个参数:均值(mean)和标准差(std)。这些参数被用来按照以下公式对图像进行归一化:
normalized value=(original value−mean)/std
在提供的示例中,均值和标准差都被设置为0.5。因此,如果原始图像的像素值范围在0到1之间,应用此归一化将如下所示:
- 减去均值0.5会使值在-0.5到0.5的范围内。
- 由于标准差也是0.5,所以将结果除以0.5将值的范围扩展到-1到1。
这种变换不仅改变了数据的范围,而且改变了数据的分布,使得网络可能更容易学习。通过将数据中心化和重新缩放,模型可能会更容易识别不同的特征和模式。
为什么ToTensor()和Normalize()两者都需要?
transforms.ToTensor()
确保所有值都在同一尺度上,无论原始图像的亮度或对比度如何。transforms.Normalize
进一步确保每个通道的分布具有特定的均值和标准差,这对于训练稳定性和与预训练模型的兼容性可能是必要的。
通过这两个步骤,我们可以确保输入数据符合模型的预期,从而更容易地训练模型并实现更好的性能。