Python中CIFAR10的图像数据预处理

文章详细介绍了CIFAR10数据集在PyTorch中的使用,包括如何通过datasets.CIFAR10加载数据,以及ToTensor和Normalize等预处理方法的作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

 完整代码

data_dir = '../data/cifar/'
apply_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                               transform=apply_transform)

test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                              transform=apply_transform)

datasets.CIFAR10()是什么意思?

要理解这短短几行代码,可以从datasets.CIFAR10(data_dir, train=True, download=True, transform=apply_transform)入手。

datasets是包含了很多个数据集的模块,其中有CAFAR10数据集。

CIFAR10是datasets中的一个类,构造函数的参数如下:

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
  • root (str): CIFAR-10 数据集的存储位置。如果数据集还未下载,且 download=True,它将被下载到这个目录。如果数据集已经在这个目录中,就会直接从这个位置加载。

  • train (bool): 决定是否加载训练集。如果设置为 True,将加载训练集;如果设置为 False,将加载测试集。

  • transform (Optional[Callable]): 可选参数,允许用户传递一个预处理转换的函数或转换的组合。这些转换将在加载数据集时应用于每个图像。通过这个参数,你可以进行诸如图像裁剪、缩放、归一化等操作。

  • target_transform (Optional[Callable]): 这个可选参数允许用户传递一个转换函数,用于处理每个图像的目标标签(例如分类标签)。通过这个参数,你可以执行一些特定的目标转换或编码。

  • download (bool): 决定如果数据集不在指定的 root 位置,是否应该下载它。如果 download=True,数据集将被下载到指定目录。如果已经存在,则不会重新下载。

datasets.CIFAR10 是 PyTorch 中用于加载 CIFAR-10 数据集的函数。CIFAR-10 是一个常用的图像分类数据集,包括10个类别和60000个32x32彩色图像。以下是该函数中各个参数的解释:

  • data_dir(数据目录):该参数指定了CIFAR-10数据集的存储位置。如果数据集还未下载,它将被下载到这个目录。

  • train(训练):这个布尔参数决定了是否加载训练集。如果 train=True,则加载训练集;如果 train=False,则加载测试集。CIFAR-10数据集分为50000个训练图像和10000个测试图像。

  • download(下载):这个布尔参数决定了如果数据集不在data_dir指定的位置,是否应该下载它。如果 download=True,数据集将被下载到指定目录。如果已经存在,则不会重新下载。

  • transform(变换):该参数是可选的,允许用户传递一个预处理转换的函数或转换的组合。在加载数据集的同时,这些转换将被应用于每个图像。例如,你可以使用 transforms.Compose 来将多个转换组合在一起,例如将图像转换为张量,然后对其进行标准化。

这个函数返回一个 torch.utils.data.Dataset 对象,可以用于训练和测试模型。它方便地整合了数据的加载、转换和批处理等功能,使得与PyTorch模型的整合更为顺畅。

体现Python灵活性的transforms.Compose() 

transforms.Compose() 是 PyTorch 中 torchvision.transforms 模块的一个功能,用于组合多个图像变换操作。你可以将一系列的图像变换操作(每个操作都是可调用的对象)传递给 transforms.Compose(),然后它会按照它们在列表中的顺序来依次应用这些变换。

这非常有用,因为在预处理图像(例如,用于深度学习模型的训练)时,通常需要按照特定的顺序执行多个步骤。transforms.Compose() 允许你方便地捆绑这些步骤。

以下是一个更完整的示例,展示了如何使用 transforms.Compose() 来组合几个常见的图像变换操作:

from torchvision import transforms

# 定义一组要组合的图像变换
transform = transforms.Compose([
    transforms.Resize(256),              # 将图像缩放到 256x256
    transforms.RandomCrop(224),          # 在图像中随机裁剪 224x224 的区域
    transforms.ToTensor(),               # 将 PIL 图像或 NumPy ndarray 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], # 归一化图像张量
                         std=[0.229, 0.224, 0.225])
])

# 然后,你可以将此组合的转换传递给数据加载器或直接应用于图像

通过使用 transforms.Compose(),你可以确保预处理步骤在整个数据集上保持一致,并且代码更加简洁易读。

强大的transforms.ToTensor()

使用transforms.ToTensor()后,每个PIL图像将转换为32x32x3的PyTorch张量,并且像素值的范围将从0-255映射到0-1的范围。这样的转换非常有用,因为许多神经网络期望输入在0-1范围内,并且预期的输入维度为[通道,高度,宽度]。这个转换将CIFAR-10数据整齐地格式化,使其可以直接用于大多数现代卷积神经网络。

排列顺序从HWC变成CHW有什么好处?

  1. 一致性:许多深度学习库和模型期望数据按照CHW格式组织。保持这种一致性有助于减少混乱,使得预处理和模型训练更流畅。

  2. 批量处理效率:当批量传递多个图像时,通道维度被组织在一起可以增加计算效率。例如,如果有一个批量大小为64的图像,每个图像的大小为3x32x32,则批量数据的大小将是64x3x32x32。这种组织方式有助于并行计算,特别是在卷积层中。

  3. 与预训练模型的兼容性:许多预训练的模型,如ResNet和VGG,都期望输入图像为CHW格式。使用CHW格式可以更容易地将图像传递给这些模型,而无需额外转换。

为什么要把PIL转换成张量?

PIL(Python Imaging Library)是一个用于打开、操作和保存许多不同图像文件格式的Python库。PIL现在是一个废弃的项目,但其后继者,Pillow库,继承了PIL并进行了改进和维护。PIL图像是PIL库处理的对象类型,用于代表图像。这个对象提供了许多方法,用于对图像进行操作,例如裁剪、旋转、调整大小、滤镜等。图像可以以各种模式存在,如RGB、L(灰度)等。

当你使用像datasets.CIFAR10这样的Torchvision数据集时,除非指定了转换,否则图像将作为PIL图像对象返回。但这个是整个图像,并不是能直接处理的数据(张量)。例如:

from PIL import Image

# 创建一个PIL图像对象
image = Image.open("example.jpg")

# 使用PIL的方法进行图像操作
image = image.rotate(45) # 旋转45度
image.show() # 显示图像

不可缺少的transforms.Normalize()

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

这表示对每个通道(例如RGB图像的R、G和B通道)减去0.5的均值,并除以0.5的标准差。这样可以确保每个通道的值都具有相似的尺度。

transforms.Normalize方法的工作原理是接收两个参数:均值(mean)和标准差(std)。这些参数被用来按照以下公式对图像进行归一化:

normalized value=(original value−mean)/std

在提供的示例中,均值和标准差都被设置为0.5。因此,如果原始图像的像素值范围在0到1之间,应用此归一化将如下所示:

  1. 减去均值0.5会使值在-0.5到0.5的范围内。
  2. 由于标准差也是0.5,所以将结果除以0.5将值的范围扩展到-1到1。

这种变换不仅改变了数据的范围,而且改变了数据的分布,使得网络可能更容易学习。通过将数据中心化和重新缩放,模型可能会更容易识别不同的特征和模式。

为什么ToTensor()和Normalize()两者都需要?

  • transforms.ToTensor()确保所有值都在同一尺度上,无论原始图像的亮度或对比度如何。
  • transforms.Normalize进一步确保每个通道的分布具有特定的均值和标准差,这对于训练稳定性和与预训练模型的兼容性可能是必要的。

通过这两个步骤,我们可以确保输入数据符合模型的预期,从而更容易地训练模型并实现更好的性能。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值