1.引言
1.1. MobileViT是什么?
MobileViT是一种基于Transformer的轻量级视觉模型,专为移动端设备上的图像分类任务而设计。
- 背景与目的:
- MobileViT由Google在2021年提出,旨在解决移动设备上的实时图像分类需求。
- 与传统的卷积神经网络(CNN)相比,MobileViT在保持高性能的同时,显著降低了计算复杂度和内存需求,从而更适应移动设备的计算能力。
- 技术特点:
- 轻量级与移动友好:MobileViT通过引入轻量级的Transformer模块和有效的降维策略,大幅减少了模型的参数数量和计算复杂度,使其能够在移动设备上高效运行。
- 基于Transformer:MobileViT采用了Transformer架构,通过自注意力机制捕获图像的全局上下文信息,提高了模型的泛化能力和准确性。
- 优化方法:MobileViT采用了一系列优化方法,如混合精度训练和自适应模型调整等,进一步提高了在移动设备上的运行效率。
- 性能表现:
- 在多个图像分类数据集上,MobileViT均取得了与现有轻量级CNN模型相当或更优的性能。例如,在ImageNet-1k数据集上,MobileViT在大约600万个参数的情况下达到了78.4%的Top-1准确率。
- MobileViT显示出更好的泛化能力,即使在使用大量数据增强的情况下,也能更好地预测未知数据集上的表现。
- 与其他基于Transformer的模型相比,MobileViT对超参数的调整相对健壮,对L2正则化等超参数的敏感度较低。
- 适用场景:MobileViT特别适用于需要实时图像分类的移动端应用,如智能手机、平板电脑等。其轻量级和高效的特点使得它成为移动视觉任务中的理想选择。
综合而言,MobileViT通过结合轻量级的Transformer架构和优化的设计,成功实现了在移动设备上高效、准确的图像分类。其优秀的性能、泛化能力和对超参数的鲁棒性使得它在移动视觉领域具有广泛的应用前景。
1.2.Transformer架构的特点
Transformer架构最初由谷歌大脑在2017年的论文《Attention Is All You Need》中提出,是一种基于自注意力机制的序列到序列(Seq2Seq)模型。自提出以来,该模型在自然语言处理(NLP)和计算机视觉(CV)等领域取得了显著的成功,并多次达到该领域内的最佳效果(SOTA)。
- 核心思想
Transformer架构的核心思想是使用自注意力机制(self-attention mechanism)来建立输入序列的表示。相比于传统的循环神经网络(RNN)架构,Transformer能够并行地处理整个序列,而不是按顺序逐步处理,从而提高了计算效率。
- 架构组成
Transformer架构主要由两个主要组件组成:编码器(Encoder)和解码器(Decoder)。
-
编码器(Encoder):
-
主要负责将输入序列转化为一种中间表示形式,这种表示形式能够捕捉输入序列中的上下文信息。
-
编码器由多个相同的层堆叠而成,每个层都包含自注意力机制和前馈神经网络(Feed-Forward Neural Network)。
-
自注意力机制允许模型在序列内的任意位置间直接建立依赖,从而更好地理解数据的上下文关系。
-
位置编码(Positional Encoding)用于提供关于单词在序列中位置的信息,因为Transformer不使用基于顺序的结构。
-
解码器(Decoder):
-
主要负责根据编码器的输出和之前的解码输出,生成新的序列。
-
解码器同样由多个相同的层堆叠而成,其结构与编码器类似,但还包含了一个额外的自注意力层和一个编码器-解码器注意力层。
-
编码器-解码器注意力层允许解码器关注编码器输出的不同位置,从而帮助生成准确的输出序列。
- 特点与优势
- 并行处理能力:Transformer能够并行地处理整个序列,而不是像RNN那样按顺序逐步处理,这大大提高了计算效率。
- 长距离依赖建模能力:通过自注意力机制,Transformer能够建模输入序列中的长距离依赖关系,这在处理长序列时尤为重要。
- 多头注意力机制:Transformer采用多头注意力机制,允许模型同时学习数据的不同表示,每个“头”关注序列的不同部分,这有助于模型捕捉更丰富的信息。
- 灵活性:Transformer架构非常灵活,可以应用于各种序列生成任务,如机器翻译、文本摘要、语音识别等。
综合而言,Transformer架构在自然语言处理领域特别流行,例如BERT和GPT等预训练语言模型就是从Transformer中衍生出来的。此外,Transformer架构也被广泛应用于计算机视觉领域,如图像分类、目标检测等任务。在智能驾驶领域,Transformer架构也被用于感知、预测和决策等各个环节。
1.3. 研究内容
在本文的例子中,我们将介绍并实现MobileViT架构,该架构是由Mehta等人提出的,它融合了Transformer(由Vaswani等人开创)和卷积神经网络的优点。通过Transformer,MobileViT能够捕获图像中的长距离依赖关系,从而生成全局表示;而卷积操作则帮助模型捕捉图像中的局部空间关系。
MobileViT的设计不仅结合了Transformer和卷积的特性,还作为一个通用且移动友好的骨干网络,适用于各种图像识别任务。据研究结果显示,在性能方面,MobileViT相比其他复杂度相近或更高的模型(如MobileNetV3)具有优势,同时保持了在移动设备上的高效运行。
请注意,为了成功运行这个示例,您需要安装TensorFlow 2.13或更高版本。
1.4. 研究意义
随着移动设备应用的广泛普及,图像分类等计算机视觉任务在移动设备上的需求日益增长。然而,传统的深度学习模型,特别是基于卷积神经网络(CNN)的模型,往往面临着计算资源和存储需求的限制,难以在移动设备上高效运行。因此,开发轻量级、高效的深度学习模型成为了一个迫切的研究需求。
MobileViT模型通过融合Transformer和卷积神经网络的优势,为解决移动设备上的图像分类问题提供了新的思路。它利用Transformer的自注意力机制捕捉图像中的长距离依赖关系,同时结合卷积操作捕捉局部空间关系,从而在保持高性能的同时降低了计算复杂度和内存需求。相比传统的轻量级CNN模型,MobileViT在多个图像分类数据集上均取得了优异的性能,证明了其在移动设备图像分类任务中的有效性和实用性。
MobileViT的研究不仅具有理论价值,还具有重要的实际应用前景。它能够为移动设备上的实时图像处理任务提供高效的解决方案,为用户带来更好的使用体验。随着移动设备性能的不断提升和计算资源的持续优化,MobileViT有望在更多领域得到应用,推动移动设备上的计算机视觉技术向前发展。同时,MobileViT的研究也为其他轻量级深度学习模型的设计和优化提供了有益的参考。
2. 部署MobileViT
2.1.设置
2.1.1.导入函数库
# 导入必要的库
import os
import tensorflow as tf
# 设置Keras的后端为TensorFlow(虽然Keras现在默认后端就是TensorFlow,但这里显式设置以确保环境配置正确)
os.environ["KERAS_BACKEND"] = "tensorflow"
# 导入Keras库以及相关的layers和backend模块
import keras
from keras import layers
from keras import backend as K
# 导入tensorflow_datasets库,用于加载数据集
import tensorflow_datasets as tfds
# 禁用tensorflow_datasets在加载数据时的进度条显示,以避免在输出中显示额外的进度信息
tfds.disable_progress_bar()
2.2.2.设置超参数
# 这些值来自表4。
patch_size = 4 # 2x2,用于Transformer块。
image_size = 256 # 输入图像的尺寸。
expansion_factor = 2 # MobileNetV2块的扩展因子。
这段代码定义了三个变量,分别用于设置Transformer块的Patch大小、输入图像的尺寸以及MobileNetV2块的扩展因子。这些参数对于构建MobileViT模型是必要的。
2.2.构建MobileViT
MobileViT架构是一个专为移动设备设计的图像分类模型,它巧妙地结合了Transformer和卷积神经网络的优点,以实现高效且准确的图像识别。
1. 输入处理
在模型的初始阶段,输入图像首先通过一系列带步长的3x3卷积层进行处理。这些卷积层不仅用于提取图像的初步特征,还通过调整步长来逐步降低特征图的分辨率,从而减少后续层的计算量。
2. MobileNetV2风格倒置残差块
在特征提取的过程中,MobileViT采用了MobileNetV2风格的倒置残差块进行特征转换和降采样。这些倒置残差块首先通过1x1卷积进行通道扩展,然后利用深度可分离卷积进行空间特征提取,最后再通过1x1卷积将特征图通道数恢复到原始大小。通过这种方式,倒置残差块能够在不增加过多计算量的前提下,有效地提高模型的特征提取能力。
3. MobileViT块
MobileViT架构的核心在于其独特的MobileViT块。这些块结合了Transformer和卷积神经网络的优点,旨在捕获图像中的长距离依赖关系和局部空间关系。具体来说,MobileViT块首先通过自注意力机制(如多头自注意力)计算特征图中不同位置之间的相关性,从而捕获长距离依赖关系。然后,它利用卷积操作对特征图进行局部空间特征的提取和融合。通过这种方式,MobileViT块能够同时利用Transformer的全局建模能力和卷积神经网络的局部特征提取能力,从而实现更高效、更准确的图像识别。
4. 输出层
经过多个MobileViT块的堆叠后,模型最终通过全局平均池化层将特征图转换为固定长度的特征向量。然后,这些特征向量被送入一个全连接层进行分类。全连接层的输出节点数与类别数相同,通过softmax函数计算每个类别的概率分布。
总体而言,MobileViT架构通过结合Transformer和卷积神经网络的优点,实现了在移动设备上进行高效、准确的图像分类。其独特的MobileViT块能够有效地捕获图像中的长距离依赖关系和局部空间关系,从而提高了模型的性能。同时,MobileViT架构还采用了MobileNetV2风格的倒置残差块进行特征转换和降采样,进一步提高了模型的计算效率。这些特点使得MobileViT成为了一个优秀的移动设备图像分类模型。
2.2.1.构建MobileViT
# 定义卷积块函数,用于构建卷积层。
def conv_block(x, filters=16, kernel_size=3, strides=2):
# 创建二维卷积层。
conv_layer = layers.Conv2D(
filters, # 过滤器数量
kernel_size, # 卷积核大小
strides=strides, # 步长
activation=keras.activations.swish, # 激活函数
padding="same", # 填充方式
)
return conv_layer(x) # 返回卷积后的输出
# 根据输入尺寸和卷积核大小,计算正确的填充量。
def correct_pad(inputs, kernel_size):
# 根据图像数据格式确定图像维度。
img_dim = 2 if backend.image_data_format() == "channels_first" else 1
input_size = inputs.shape[img_dim : (img_dim + 2)]
# 将卷积核大小转换为元组,如果它是一个整数。
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
# 计算调整值,用于确保卷积后尺寸的正确性。
if input_size[0] is None:
adjust = (1, 1)
else:
adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
correct = (kernel_size[0] // 2, kernel_size[1] // 2)
# 返回需要添加的填充量。
return (
(correct[0] - adjust[0], correct[0]),
(correct[1] - adjust[1], correct[1]),
)
# 定义反残差块,用于构建轻量级卷积神经网络中的反残差结构。
def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
# 使用1x1卷积进行通道扩展。
m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
m = layers.BatchNormalization()(m)
m = keras.activations.swish(m)
# 如果步长大于1,则使用零填充。
if strides == 2:
m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
# 使用深度可分离卷积进行空间维度的降采样。
m = layers.DepthwiseConv2D(
3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
)(m)
m = layers.BatchNormalization()(m)
m = keras.activations.swish(m)
# 使用1x1卷积将通道数降至输出通道数。
m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
m = layers.BatchNormalization()(m)
# 如果步长为1且输入输出通道数相同,则使用残差连接。
if keras.ops.equal(x.shape[-1], output_channels) and strides == 1:
return layers.Add()([m, x])
return m
# 定义多层感知机(MLP)函数,用于Transformer中的前馈网络。
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=keras.activations.swish)(x)
x = layers.Dropout(dropout_rate)(x)
return x
# 定义Transformer块函数,用于构建Transformer模型中的自注意力机制。
def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
for _ in range(transformer_layers):
# 第一层归一化。
x1 = layers.LayerNormalization(epsilon=1e-6)(x)
# 创建多头注意力层。
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# 第一个残差连接。
x2 = layers.Add()([attention_output, x])
# 第二层归一化。
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP。
x3 = mlp(
x3,
hidden_units=[x.shape[-1] * 2, x.shape[-1]],
dropout_rate=0.1,
)
# 第二个残差连接。
x = layers.Add()([x3, x2])
return x
# 定义MobileViT块,结合了局部特征提取和全局特征提取。
def mobilevit_block(x, num_blocks, projection_dim, strides=1):
# 使用卷积进行局部特征提取。
local_features = conv_block(x, filters=projection_dim, strides=strides)
local_features = conv_block(
local_features, filters=projection_dim, kernel_size=1, strides=strides
)
# 将特征图划分为不重叠的patches,并通过Transformer块处理。
num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
local_features
)
global_features = transformer_block(
non_overlapping_patches, num_blocks, projection_dim
)
# 将Transformer的输出重新整理成特征图的形状。
folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
global_features
)
# 使用1x1卷积将特征图的通道数调整为与输入匹配,并与输入特征图进行拼接。
folded_feature_map = conv_block(
folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
)
local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])
# 使用卷积层融合局部和全局特征。
local_global_features = conv_block(
local_global_features, filters=projection_dim, strides=strides
)
return local_global_features
上述代码定义了一系列用于构建和操作深度学习模型,特别是MobileViT模型的函数。
-
conv_block:
- 功能:创建一个卷积块,包含卷积层、激活函数(Swish)和批量归一化。
- 用途:用于提取图像特征,可以作为更复杂模型的一部分。
-
correct_pad:
- 功能:计算进行卷积操作时所需的填充量,以确保输出尺寸正确。
- 用途:在对输入图像进行卷积操作之前调整边界填充。
-
inverted_residual_block:
- 功能:实现MobileNetV2中的反残差结构,包含点卷积、深度卷积和批量归一化。
- 用途:构建轻量级网络结构,用于减少模型参数和计算量。
-
mlp:
- 功能:实现多层感知机(MLP),用于Transformer中的前馈网络部分。
- 用途:在Transformer模型中进行特征的非线性变换。
-
transformer_block:
- 功能:构建Transformer块,包含多头自注意力机制和前馈网络。
- 用途:处理序列数据,捕获长距离依赖关系,用于图像的全局特征提取。
-
mobilevit_block:
- 功能:结合局部特征提取(通过卷积)和全局特征提取(通过Transformer)的MobileViT块。
- 用途:作为MobileViT模型的核心组件,实现图像的高效特征提取和表示。
整体来看,这些函数共同构成了一个深度学习模型的框架,特别是针对移动设备优化的视觉Transformer模型(MobileViT)。它们涵盖了从数据预处理(如填充和归一化)到特征提取(卷积和Transformer操作)的各个步骤,最终实现图像分类或其他视觉任务。
2.2.2.实例化MobileViT块
关于MobileViT块的深入解析:
在MobileViT架构中,MobileViT块是关键组成部分,它融合了卷积和Transformer的优势。首先,输入的特征表示(A)通过一系列卷积层,这些卷积层专注于捕获图像中的局部细节和空间关系。这些特征图的典型形状是(h, w, num_channels),其中h代表高度,w代表宽度,num_channels是通道数。
随后,这些特征图被分割成一系列非重叠的小补丁(patches),每个补丁的大小为p×p,其中p表示补丁的边长。这些小补丁被重新组织成一个二维数组,形状为(p^2, n, num_channels),其中n表示整个图像中被分割成的补丁数量,计算公式为n = (h * w) / (p * p)。这个过程可以看作是“展开”操作,将二维特征图转化为一个包含多个补丁的一维序列。
接下来,这个一维序列通过Transformer块进行处理。Transformer块利用自注意力机制来捕获补丁之间的全局依赖关系,从而能够捕捉图像中的长距离依赖。这种全局建模能力是Transformer架构的核心优势,尤其对于理解复杂图像结构和识别高级别概念非常有效。
经过Transformer块处理后,输出向量(B)再次被“折叠”回二维特征图的形状(h, w, num_channels)。这个过程与之前的“展开”操作相反,它将一维序列重新组织成二维特征图,以便后续处理。
最后,原始的特征表示(A)和经过Transformer处理后的特征表示(B)通过两个额外的卷积层进行融合。这两个卷积层的作用是将局部和全局特征进行结合,生成更加丰富的特征表示。值得注意的是,在这个过程中,特征图的空间分辨率保持不变,这有助于保持模型对图像细节的敏感度。
从某种角度来看,MobileViT块可以被视为一种特殊的卷积块,它结合了卷积的局部特征提取能力和Transformer的全局建模能力。这种设计使得MobileViT架构能够在保持较低计算复杂度的同时,实现较高的图像分类准确率。
在构建MobileViT架构时,多个MobileViT块被组合在一起,形成一个完整的网络结构。以下是从原始论文中引用的示意图,展示了MobileViT架构的一个具体实例(如XXS变体):(请注意,由于这里不能直接插入图像,我们将省略具体的示意图。)
def create_mobilevit(num_classes=5):
# 定义输入层,假设输入图像大小为 image_size x image_size,具有3个颜色通道。
inputs = keras.Input((image_size, image_size, 3))
# 对输入图像进行归一化处理,将像素值缩放到0到1之间。
x = layers.Rescaling(scale=1.0 / 255)(inputs)
# 开始卷积干线部分,使用 conv_block 函数创建第一个卷积层。
x = conv_block(x, filters=16)
# 使用 inverted_residual_block 函数创建 MobileNetV2 风格的反残差块。
x = inverted_residual_block(
x, expanded_channels=16 * expansion_factor, output_channels=16
)
# 使用 MV2 块进行下采样。
# 第一次下采样,步长为2,输出通道数增加到24。
x = inverted_residual_block(
x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
)
# 继续使用 MV2 块进行特征提取,保持通道数不变。
x = inverted_residual_block(
x, expanded_channels=24 * expansion_factor, output_channels=24
)
# 再次使用 MV2 块进行特征提取。
x = inverted_residual_block(
x, expanded_channels=24 * expansion_factor, output_channels=24
)
# 第一个 MV2 块到 MobileViT 块的转换。
# 第二次下采样,步长为2,输出通道数增加到48。
x = inverted_residual_block(
x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
)
# 使用 mobilevit_block 函数创建 MobileViT 块,包含2个 Transformer 层。
x = mobilevit_block(x, num_blocks=2, projection_dim=64)
# 第二个 MV2 块到 MobileViT 块的转换。
# 继续下采样,步长为2,输出通道数增加到64。
x = inverted_residual_block(
x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
)
# 使用 mobilevit_block 函数创建 MobileViT 块,包含4个 Transformer 层。
x = mobilevit_block(x, num_blocks=4, projection_dim=80)
# 第三个 MV2 块到 MobileViT 块的转换。
# 再次下采样,步长为2,输出通道数增加到80。
x = inverted_residual_block(
x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
)
# 使用 mobilevit_block 函数创建 MobileViT 块,包含3个 Transformer 层。
x = mobilevit_block(x, num_blocks=3, projection_dim=96)
# 使用 conv_block 进行1x1卷积,用于通道数的调整。
x = conv_block(x, filters=320, kernel_size=1, strides=1)
# 分类头,使用全局平均池化层和全连接层进行分类。
x = layers.GlobalAvgPool2D()(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
# 创建 Keras 模型,输入为之前定义的 inputs,输出为分类结果 outputs。
return keras.Model(inputs, outputs)
# 实例化 MobileViT 模型,类别数默认为5。
mobilevit_xxs = create_mobilevit()
# 打印模型的概述信息,包括每层的输出形状和参数数量。
mobilevit_xxs.summary()
这段代码定义了一个创建MobileViT模型的函数 create_mobilevit
,并实例化了这个模型,然后打印出了模型的概述。
-
函数定义:
create_mobilevit
: 这个函数接受一个参数num_classes
,表示分类任务的类别数,默认为5。
-
输入层:
inputs
: 使用keras.Input
定义模型的输入,假设输入图像的大小是image_size x image_size
,具有3个颜色通道。
-
数据预处理:
Rescaling
: 对输入图像进行重缩放,归一化到[0,1]区间。
-
初始卷积层:
conv_block
: 应用一个卷积块作为模型的起始部分。
-
反残差块:
inverted_residual_block
: 使用MobileNetV2中的反残差结构进行下采样和特征提取。
-
MobileViT块:
mobilevit_block
: 结合了卷积和Transformer结构的MobileViT块,用于提取局部和全局特征。
-
分类头:
GlobalAvgPool2D
: 使用全局平均池化层来减少特征的空间维度。Dense
: 使用全连接层进行分类,激活函数为Softmax,输出类别概率。
-
模型实例化:
mobilevit_xxs
: 调用create_mobilevit
函数实例化MobileViT模型。
-
模型概述:
summary
: 打印模型的概述信息,包括每层的名称、输出形状和参数数量。
这个函数构建了一个轻量级的深度学习模型,适用于移动设备上的图像分类任务。模型结合了卷积神经网络的局部特征提取能力和Transformer的全局特征提取能力,通过多个MobileViT块和反残差块进行特征提取,最终通过分类头输出预测结果。通过调用 mobilevit_xxs.summary()
,用户可以快速了解模型的结构和参数量。
2.3 数据预处理
2.3.1.加载数据
我们将使用 tf_flowers
数据集来演示该模型。与其他基于Transformer的架构不同,MobileViT使用了一个简单的数据增强流程,这主要是因为它具有CNN(卷积神经网络)的特性。
# 定义批次大小和自动调优参数
batch_size = 64
auto = tf.data.AUTOTUNE
# 定义在训练时使用的更大的图像尺寸
resize_bigger = 280
# 定义类别数
num_classes = 5
# 定义数据预处理函数
def preprocess_dataset(is_training=True):
# 定义内部函数,用于处理单个图像和标签
def _pp(image, label):
if is_training:
# 如果是在训练阶段,先将图像调整到更大的分辨率,然后随机裁剪到所需的尺寸
image = tf.image.resize(image, (resize_bigger, resize_bigger))
image = tf.image.random_crop(image, (image_size, image_size, 3))
# 随机水平翻转图像
image = tf.image.random_flip_left_right(image)
else:
# 如果是在测试或验证阶段,直接将图像调整到所需的尺寸
image = tf.image.resize(image, (image_size, image_size))
# 将标签转换为独热编码
label = tf.one_hot(label, depth=num_classes)
return image, label
# 返回内部函数
return _pp
# 定义数据集准备函数
def prepare_dataset(dataset, is_training=True):
# 如果是在训练阶段,先对数据集进行洗牌
if is_training:
dataset = dataset.shuffle(batch_size * 10)
# 使用映射函数并行地应用预处理函数
dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
# 将数据集分批并使用预取操作优化性能
return dataset.batch(batch_size).prefetch(auto)
这段代码定义了两个函数,preprocess_dataset
和 prepare_dataset
,用于准备和预处理数据集。preprocess_dataset
函数根据是否处于训练阶段,对图像执行不同的预处理操作,包括调整图像大小、随机裁剪、随机水平翻转和标签的独热编码。prepare_dataset
函数则用于对整个数据集应用预处理函数,并进行洗牌、分批处理和预取操作,以优化数据加载过程。
2.3.2. 数据预处理
# 使用 TensorFlow Datasets 库加载 tf_flowers 数据集,分为训练集和验证集。
# 训练集占90%,验证集占10%。
train_dataset, val_dataset = tfds.load(
"tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
# 获取训练集和验证集的样本数量。
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
# 打印训练集和验证集的样本数量。
print(f"Number of training examples: {
num_train}") # 训练样本数
print(f"Number of validation examples: {
num_val}") # 验证样本数
# 使用之前定义的 prepare_dataset 函数准备训练集和验证集。
# 训练集使用 is_training=True 进行数据增强。
train_dataset = prepare_dataset(train_dataset, is_training=True