昇思MindSpore学习总结十四 —— Vision Transformer图像分类

1、Vision Transformer(ViTal)简介

        近些年,随着基于自注意(self-Attention)结构的模型发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展,由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

        ViT则是自然语言处理和计算机视觉两个领域的融合结晶,在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好地效果。

1.1 模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

1.2 模型特点

ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  1. 数据集的原图像被划分为多个patch(图像块)后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  2. 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  3. 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

2、环境准备与数据读取

开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore。

        首先我们需要下载本案例的数据集,可通过http://image-net.org下载完整的ImageNet数据集,本案例应用的数据集是从ImageNet中筛选出来的子集。

运行第一段代码时会自动下载并解压,请确保你的数据集路径如以下结构。

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

path = download(dataset_url, path, kind="zip", replace=True)

import os

import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms


data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

 3、模型解析

下面将通过代码来细致剖析ViT模型的内部结构。

3.1 Transformer基本原理

        Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

其主要结构为多个Encoder和Decoder模块所组成,其中Encoder和Decoder的详细结构如下图[2]所示:

 

        Encoder与Decoder由许多结构组成,如:多头注意力(Multi-Head Attention)层,Feed Forward层,Normaliztion层,甚至残差连接(Residual Connection,图中的“Add”)。不过,其中最重要的结构是多头注意力(Multi-Head Attention)结构,该结构基于自注意力(Self-Attention)机制,是多个Self-Attention的并行组成。

所以,理解了Self-Attention就抓住了Transformer的核心。

3.2 Attention模块

        以下是Self-Attention的解释,其核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。

在Self-Attention中:

  1. 最初的输入向量首先会经过Embedding层映射成Q(Query),K(Key),V(Value)三个向量,由于是并行操作,所以代码中是映射成为dim x 3的向量然后进行分割,换言之,如果你的输入向量为一个向量序列(𝑥1,𝑥2,𝑥3),其中的𝑥1,𝑥2,𝑥3都是一维向量,那么每一个一维向量都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为,Q,K,V三个矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是Q,K,V三个,主要是因为需要两个向量点乘以获得权重,又需要另一个向量来承载权重向加的结果,所以,最少需要3个矩阵。
  2. 自注意力机制的自注意主要体现在它的Q,K,V都来源于其自身,也就是该过程是在提取输入的不同顺序的向量的联系与特征,最终通过不同顺序向量之间的联系紧密性(Q与K乘积经过Softmax的结果)来表现出来。Q,K,V得到后就需要获取向量间权重,需要对Q和K进行点乘并除以维度的平方根,对所有向量的结果进行Softmax处理,通过公式(2)的操作,我们获得了向量之间的关系权重。
  3. 其最终输出则是通过V这个映射后的向量与Q,K经过Softmax结果进行weight sum获得,这个过程可以理解为在全局上进行自注意表示。每一组Q,K,V最后都有一个V输出,这是Self-Attention得到的最终结果,是当前向量在结合了它与其他向量关联权重后得到的结

 通过下图可以整体把握Self-Attention的全部过程。

        多头注意力机制就是将原本self-Attention处理的向量分割为多个Head进行处理,这一点也可以从代码中体现,这也是attention结构可以进行并行加速的一个方面。

        总结来说,多头注意力机制在保持参数总量不变的情况下,将同样的query, key和value映射到原来的高维空间(Q,K,V)的不同子空间(Q_0,K_0,V_0)中进行自注意力的计算,最后再合并不同子空间中的注意力信息。

        所以,对于同一个输入向量,多个注意力机制可以同时对其进行处理,即利用并行计算加速处理过程,又在处理的时候更充分的分析和利用了向量特征。下图展示了多头注意力机制,其并行能力的主要体现在下图中的𝑎1𝑎1和𝑎2𝑎2是同一个向量进行分割获得的。

 以下是Multi-Head Attention代码,结合上文的解释,代码清晰的展现了这一过程。

from mindspore import nn, ops  # 导入MindSpore的神经网络模块(nn)和操作模块(ops)

class Attention(nn.Cell):  # 定义一个名为Attention的类,继承自nn.Cell
    def __init__(self,  # 定义类的初始化方法
                 dim: int,  # 输入维度
                 num_heads: int = 8,  # 注意力头数,默认为8
                 keep_prob: float = 1.0,  # 输出保留概率,默认为1.0
                 attention_keep_prob: float = 1.0):  # 注意力保留概率,默认为1.0
        super(Attention, self).__init__()  # 调用父类的初始化方法

        self.num_heads = num_heads  # 将num_heads赋值给实例变量self.num_heads
        head_dim = dim // num_heads  # 每个头的维度
        self.scale = ms.Tensor(head_dim ** -0.5)  # 计算缩放因子并转换为张量

        self.qkv = nn.Dense(dim, dim * 3)  # 定义一个全连接层,将输入维度映射到3倍的输入维度
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)  # 定义一个Dropout层,用于注意力机制的丢弃
        self.out = nn.Dense(dim, dim)  # 定义另一个全连接层,将输入维度映射到输出维度
        self.out_drop = nn.Dropout(p=1.0-keep_prob)  # 定义一个Dropout层,用于输出的丢弃
        self.attn_matmul_v = ops.BatchMatMul()  # 定义一个批量矩阵乘法操作
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)  # 定义一个带转置的批量矩阵乘法操作
        self.softmax = nn.Softmax(axis=-1)  # 定义一个Softmax层,用于计算注意力权重

    def construct(self, x):  # 定义类的前向传播方法
        """Attention construct."""
        b, n, c = x.shape  # 获取输入张量的形状(batch_size, num_tokens, dim)
        qkv = self.qkv(x)  # 通过全连接层获取查询、键和值
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))  # 调整形状以分离查询、键和值
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))  # 转置以便于后续计算
        q, k, v = ops.unstack(qkv, axis=0)  # 分离查询、键和值
        attn = self.q_matmul_k(q, k)  # 计算查询和键的点积
        attn = ops.mul(attn, self.scale)  # 缩放点积结果
        attn = self.softmax(attn)  # 应用Softmax函数计算注意力权重
        attn = self.attn_drop(attn)  # 进行Dropout
        out = self.attn_matmul_v(attn, v)  # 将注意力权重与值进行矩阵乘法
        out = ops.transpose(out, (0, 2, 1, 3))  # 转置结果
        out = ops.reshape(out, (b, n, c))  # 调整形状回到原输入形状
        out = self.out(out)  # 通过全连接层获取最终输出
        out = self.out_drop(out)  # 进行Dropout

        return out  # 返回最终输出

3.3 Transformer Encoder

        在了解了Self-Attention结构之后,通过与Feed Forward,Residual Connection等结构的拼接就可以形成Transformer的基础结构,下面代码实现了Feed Forward,Residual Connection结构。

from typing import Optional, Dict  # 导入Optional和Dict类型提示

class FeedForward(nn.Cell):  # 定义一个名为FeedForward的类,继承自nn.Cell
    def __init__(self,  # 定义类的初始化方法
                 in_features: int,  # 输入特征维度
                 hidden_features: Optional[int] = None,  # 隐藏层特征维度,可选,默认与输入特征维度相同
                 out_features: Optional[int] = None,  # 输出特征维度,可选,默认与输入特征维度相同
                 activation: nn.Cell = nn.GELU,  # 激活函数,默认使用GELU
                 keep_prob: float = 1.0):  # Dropout保留概率,默认1.0
        super(FeedForward, self).__init__()  # 调用父类的初始化方法
        out_features = out_features or in_features  # 如果未指定输出特征维度,则与输入特征维度相同
        hidden_features = hidden_features or in_features  # 如果未指定隐藏层特征维度,则与输入特征维度相同
        self.dense1 = nn.Dense(in_features, hidden_features)  # 定义第一个全连接层
        self.activation = activation()  # 定义激活函数
        self.dense2 = nn.Dense(hidden_features, out_features)  # 定义第二个全连接层
        self.dropout = nn.Dropout(p=1.0-keep_prob)  # 定义Dropout层

    def construct(self, x):  # 定义类的前向传播方法
        """Feed Forward construct."""
        x = self.dense1(x)  # 通过第一个全连接层
        x = self.activation(x)  # 应用激活函数
        x = self.dropout(x)  # 进行Dropout
        x = self.dense2(x)  # 通过第二个全连接层
        x = self.dropout(x)  # 再次进行Dropout

        return x  # 返回最终输出

class ResidualCell(nn.Cell):  # 定义一个名为ResidualCell的类,继承自nn.Cell
    def __init__(self, cell):  # 定义类的初始化方法
        super(ResidualCell, self).__init__()  # 调用父类的初始化方法
        self.cell = cell  # 将传入的cell赋值给实例变量self.cell

    def construct(self, x):  # 定义类的前向传播方法
        """ResidualCell construct."""
        return self.cell(x) + x  # 返回cell的输出与输入的和,实现残差连接

        接下来就利用Self-Attention来构建ViT模型中的TransformerEncoder部分,类似于构建了一个Transformer的编码器部分,如下图[1]所示:

  1. ViT模型中的基础结构与标准Transformer有所不同,主要在于Normalization的位置是放在Self-Attention和Feed Forward之前,其他结构如Residual Connection,Feed Forward,Normalization都如Transformer中所设计。

  2. 从Transformer结构的图片可以发现,多个子encoder的堆叠就完成了模型编码器的构建,在ViT模型中,依然沿用这个思路,通过配置超参数num_layers,就可以确定堆叠层数。

  3. Residual Connection,Normalization的结构可以保证模型有很强的扩展性(保证信息经过深层处理不会出现退化的现象,这是Residual Connection的作用),Normalization和dropout的应用可以增强模型泛化能力。

从以下源码中就可以清晰看到Transformer的结构。将TransformerEncoder结构和一个多层感知器(MLP)结合,就构成了ViT模型的backbone部分。

class TransformerEncoder(nn.Cell):  # 定义一个名为TransformerEncoder的类,继承自nn.Cell
    def __init__(self,  # 定义类的初始化方法
                 dim: int,  # 输入特征维度
                 num_layers: int,  # 编码器层数
                 num_heads: int,  # 注意力头数
                 mlp_dim: int,  # 前馈神经网络隐藏层维度
                 keep_prob: float = 1.,  # Dropout保留概率,默认1.0
                 attention_keep_prob: float = 1.0,  # 注意力层Dropout保留概率,默认1.0
                 drop_path_keep_prob: float = 1.0,  # 残差连接Dropout保留概率,默认1.0
                 activation: nn.Cell = nn.GELU,  # 激活函数,默认使用GELU
                 norm: nn.Cell = nn.LayerNorm):  # 归一化层,默认使用LayerNorm
        super(TransformerEncoder, self).__init__()  # 调用父类的初始化方法
        layers = []  # 定义一个空列表用于存储各层

        for _ in range(num_layers):  # 循环创建num_layers个编码器层
            normalization1 = norm((dim,))  # 创建第一个归一化层
            normalization2 = norm((dim,))  # 创建第二个归一化层
            attention = Attention(dim=dim,  # 创建注意力层
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,  # 创建前馈神经网络层
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            layers.append(  # 将残差连接后的层添加到layers列表
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),  # 第一层:归一化 + 注意力 + 残差连接
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))  # 第二层:归一化 + 前馈神经网络 + 残差连接
                ])
            )
        self.layers = nn.SequentialCell(layers)  # 将所有层组合成一个SequentialCell

    def construct(self, x):  # 定义类的前向传播方法
        """Transformer construct."""
        return self.layers(x)  # 通过所有层,并返回最终输出

3.4 ViT模型的输入

        传统的Transformer结构主要用于处理自然语言领域的词向量(Word Embedding or Word Vector),词向量与传统图像数据的主要区别在于,词向量通常是一维向量进行堆叠,而图片则是二维矩阵的堆叠,多头注意力机制在处理一维词向量的堆叠时会提取词向量之间的联系也就是上下文语义,这使得Transformer在自然语言处理领域非常好用,而二维图片矩阵如何与一维词向量进行转化就成为了Transformer进军图像处理领域的一个小门槛。

在ViT模型中:

  1. 通过将输入图像在每个channel上划分为1616个patch,这一步是通过卷积操作来完成的,当然也可以人工进行划分,但卷积操作也可以达到目的同时还可以进行一次而外的数据处理;*例如一幅输入224 x 224的图像,首先经过卷积处理得到16 x 16个patch,那么每一个patch的大小就是14 x 14。

  2. 再将每一个patch的矩阵拉伸成为一个一维向量,从而获得了近似词向量堆叠的效果。上一步得到的14 x 14的patch就转换为长度为196的向量。

这是图像输入网络经过的第一步处理。具体Patch Embedding的代码如下所示:

class PatchEmbedding(nn.Cell):  # 定义一个名为PatchEmbedding的类,继承自nn.Cell
    MIN_NUM_PATCHES = 4  # 定义一个常量,表示最小的patch数量

    def __init__(self,  # 定义类的初始化方法
                 image_size: int = 224,  # 输入图像大小,默认224
                 patch_size: int = 16,  # 每个patch的大小,默认16
                 embed_dim: int = 768,  # 嵌入维度,默认768
                 input_channels: int = 3):  # 输入通道数,默认3(例如RGB图像)
        super(PatchEmbedding, self).__init__()  # 调用父类的初始化方法

        self.image_size = image_size  # 保存输入图像大小
        self.patch_size = patch_size  # 保存patch大小
        self.num_patches = (image_size // patch_size) ** 2  # 计算patch数量
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)  # 定义卷积层,将图像分成patch并映射到嵌入维度

    def construct(self, x):  # 定义类的前向传播方法
        """Patch Embedding construct."""
        x = self.conv(x)  # 通过卷积层,将图像分成patch并映射到嵌入维度
        b, c, h, w = x.shape  # 获取卷积输出的形状(批量大小,通道,高度,宽度)
        x = ops.reshape(x, (b, c, h * w))  # 调整形状,将高度和宽度展平
        x = ops.transpose(x, (0, 2, 1))  # 转置,以便每个patch作为一个独立的输入

        return x  # 返回最终的patch嵌入

输入图像在划分为patch之后,会经过pos_embedding 和 class_embedding两个过程。

  1. class_embedding主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增加一个类别值,通常是加在向量的第一位,上一步得到的196维的向量加上class_embedding后变为197维。

  2. 增加的class_embedding是一个可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输出来决定最后的输出类别;由于输入是16 x 16个patch,所以输出进行分类时是取 16 x 16个class_embedding进行分类。

  3. pos_embedding也是一组可以学习的参数,会被加入到经过处理的patch矩阵中。

  4. 由于pos_embedding也是可以学习的参数,所以它的加入类似于全链接网络和卷积的bias。这一步就是创造一个长度维197的可训练向量加入到经过class_embedding的向量中。

        实际上,pos_embedding总共有4种方案。但是经过作者的论证,只有加上pos_embedding和不加pos_embedding有明显影响,至于pos_embedding是一维还是二维对分类结果影响不大,所以,在我们的代码中,也是采用了一维的pos_embedding,由于class_embedding是加在pos_embedding之前,所以pos_embedding的维度会比patch拉伸后的维度加1。

        总的而言,ViT模型还是利用了Transformer模型在处理上下文语义时的优势,将图像转换为一种“变种词向量”然后进行处理,而这样转换的意义在于,多个patch之间本身具有空间联系,这类似于一种“空间语义”,从而获得了比较好的处理效果。

3.5 整体构建ViT

以下代码构建了一个完整的ViT模型。

from mindspore.common.initializer import Normal  # 导入Normal初始化器
from mindspore.common.initializer import initializer  # 导入initializer函数
from mindspore import Parameter  # 导入Parameter类

def init(init_type, shape, dtype, name, requires_grad):  # 定义初始化函数
    """Init."""
    initial = initializer(init_type, shape, dtype).init_data()  # 使用指定初始化器初始化数据
    return Parameter(initial, name=name, requires_grad=requires_grad)  # 返回一个带有初始化数据的参数

class ViT(nn.Cell):  # 定义一个名为ViT的类,继承自nn.Cell
    def __init__(self,  # 定义类的初始化方法
                 image_size: int = 224,  # 输入图像大小,默认224
                 input_channels: int = 3,  # 输入通道数,默认3(例如RGB图像)
                 patch_size: int = 16,  # 每个patch的大小,默认16
                 embed_dim: int = 768,  # 嵌入维度,默认768
                 num_layers: int = 12,  # 编码器层数,默认12
                 num_heads: int = 12,  # 注意力头数,默认12
                 mlp_dim: int = 3072,  # 前馈神经网络隐藏层维度,默认3072
                 keep_prob: float = 1.0,  # Dropout保留概率,默认1.0
                 attention_keep_prob: float = 1.0,  # 注意力层Dropout保留概率,默认1.0
                 drop_path_keep_prob: float = 1.0,  # 残差连接Dropout保留概率,默认1.0
                 activation: nn.Cell = nn.GELU,  # 激活函数,默认使用GELU
                 norm: Optional[nn.Cell] = nn.LayerNorm,  # 归一化层,默认使用LayerNorm
                 pool: str = 'cls') -> None:  # 池化方式,默认使用cls token
        super(ViT, self).__init__()  # 调用父类的初始化方法

        self.patch_embedding = PatchEmbedding(image_size=image_size,  # 创建patch嵌入层
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches  # 获取patch数量

        self.cls_token = init(init_type=Normal(sigma=1.0),  # 初始化cls token
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0),  # 初始化位置嵌入
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool  # 保存池化方式
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)  # 定义Dropout层
        self.norm = norm((embed_dim,))  # 定义归一化层
        self.transformer = TransformerEncoder(dim=embed_dim,  # 创建Transformer编码器
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)  # 定义Dropout层
        self.dense = nn.Dense(embed_dim, num_classes)  # 定义全连接层,将嵌入维度映射到类别数

    def construct(self, x):  # 定义类的前向传播方法
        """ViT construct."""
        x = self.patch_embedding(x)  # 获取patch嵌入
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))  # 复制cls token
        x = ops.concat((cls_tokens, x), axis=1)  # 将cls token与patch嵌入连接
        x += self.pos_embedding  # 加上位置嵌入

        x = self.pos_dropout(x)  # 进行Dropout
        x = self.transformer(x)  # 通过Transformer编码器
        x = self.norm(x)  # 进行归一化
        x = x[:, 0]  # 选择cls token对应的输出
        if self.training:  # 如果在训练模式
            x = self.dropout(x)  # 进行Dropout
        x = self.dense(x)  # 通过全连接层

        return x  # 返回最终输出

整体流程图如下所示:

4、模型训练与推理

4.1 模型训练

模型开始训练前,需要设定损失函数,优化器,回调函数等。

        完整训练ViT模型需要很长的时间,实际应用时建议根据项目需要调整epoch_size,当正常输出每个Epoch的step信息时,意味着训练正在进行,通过模型输出可以查看当前训练的loss值和时间等指标。

from mindspore.nn import LossBase  # 导入LossBase类
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint  # 导入训练相关类
from mindspore import train  # 导入训练模块

# 定义超参数
epoch_size = 10  # 训练轮数
momentum = 0.9  # 动量因子
num_classes = 1000  # 类别数
resize = 224  # 图像大小
step_size = dataset_train.get_dataset_size()  # 每轮训练步数

# 构建模型
network = ViT()  # 创建ViT模型

# 加载预训练模型
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"  # 预训练模型的URL
path = "./ckpt/vit_b_16_224.ckpt"  # 本地存储路径

vit_path = download(vit_url, path, replace=True)  # 下载预训练模型
param_dict = ms.load_checkpoint(vit_path)  # 加载模型参数
ms.load_param_into_net(network, param_dict)  # 将参数加载到网络中

# 定义学习率
lr = nn.cosine_decay_lr(min_lr=float(0),  # 最小学习率
                        max_lr=0.00005,  # 最大学习率
                        total_step=epoch_size * step_size,  # 总步数
                        step_per_epoch=step_size,  # 每轮步数
                        decay_epoch=10)  # 衰减轮数

# 定义优化器
network_opt = nn.Adam(network.trainable_params(), lr, momentum)  # 创建Adam优化器

# 定义损失函数
class CrossEntropySmooth(LossBase):  # 定义交叉熵损失函数类
    """CrossEntropy."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()  # 调用父类初始化方法
        self.onehot = ops.OneHot()  # 创建OneHot操作
        self.sparse = sparse  # 是否使用稀疏标签
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)  # onehot中的正值
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)  # onehot中的负值
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)  # 创建Softmax交叉熵损失

    def construct(self, logit, label):  # 定义前向传播方法
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)  # 将标签转换为onehot编码
        loss = self.ce(logit, label)  # 计算交叉熵损失
        return loss  # 返回损失值

network_loss = CrossEntropySmooth(sparse=True,  # 创建交叉熵损失实例
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# 设置检查点
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)  # 配置检查点参数
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)  # 创建检查点回调

# 初始化模型
ascend_target = (ms.get_context("device_target") == "Ascend")  # 检查是否使用Ascend设备
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")  # 使用混合精度
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")  # 不使用混合精度

# 训练模型
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],  # 使用检查点、损失监控和时间监控回调
            dataset_sink_mode=False)  # 不使用数据集下沉模式

4.2 模型验证

模型验证过程主要应用了ImageFolderDataset,CrossEntropySmooth和Model等接口。

ImageFolderDataset主要用于读取数据集。

CrossEntropySmooth是损失函数实例化接口。

Model主要用于编译模型。

        与训练过程相似,首先进行数据增强,然后定义ViT网络结构,加载预训练模型参数。随后设置损失函数,评价指标等,编译模型后进行验证。本案例采用了业界通用的评价标准Top_1_Accuracy和Top_5_Accuracy评价指标来评价模型表现。

        在本案例中,这两个指标代表了在输出的1000维向量中,以最大值或前5的输出值所代表的类别为预测结果时,模型预测的准确率。这两个指标的值越大,代表模型准确率越高。

dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)  # 创建验证数据集并打乱顺序

trans_val = [  # 定义数据预处理步骤
    transforms.Decode(),  # 解码图像
    transforms.Resize(224 + 32),  # 调整图像大小
    transforms.CenterCrop(224),  # 中心裁剪图像
    transforms.Normalize(mean=mean, std=std),  # 标准化图像
    transforms.HWC2CHW()  # 转换图像维度
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])  # 应用预处理步骤
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)  # 将数据集分成批次

# 构建模型
network = ViT()  # 创建ViT模型

# 加载预训练模型参数
param_dict = ms.load_checkpoint(vit_path)  # 加载模型参数
ms.load_param_into_net(network, param_dict)  # 将参数加载到网络中

# 定义损失函数
network_loss = CrossEntropySmooth(sparse=True,  # 创建交叉熵损失实例
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# 定义评估指标
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),  # 定义Top-1准确率
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}  # 定义Top-5准确率

if ascend_target:  # 检查是否使用Ascend设备
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")  # 使用混合精度
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")  # 不使用混合精度

# 评估模型
result = model.eval(dataset_val)  # 在验证数据集上评估模型
print(result)  # 打印评估结果

         从结果可以看出,由于我们加载了预训练模型参数,模型的Top_1_Accuracy和Top_5_Accuracy达到了很高的水平,实际项目中也可以以此准确率为标准。如果未使用预训练模型参数,则需要更多的epoch来训练。

4.3 模型推理

        在进行模型推理之前,首先要定义一个对推理图片进行数据预处理的方法。该方法可以对我们的推理图片进行resize和normalize处理,这样才能与我们训练时的输入数据匹配。

        本案例采用了一张Doberman的图片作为推理图片来测试模型表现,期望模型可以给出正确的预测结果。

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

接下来,我们将调用模型的predict方法进行模型。

        在推理过程中,通过index2label就可以获取对应标签,再通过自定义的show_result接口将结果写在对应图片上。

import os  # 导入os模块,用于文件和目录操作
import pathlib  # 导入pathlib模块,用于处理路径
import cv2  # 导入OpenCV库,用于图像处理
import numpy as np  # 导入NumPy库,用于数值计算
from PIL import Image  # 导入PIL库中的Image类,用于图像处理
from enum import Enum  # 导入枚举类,用于定义枚举类型
from scipy import io  # 导入SciPy库中的io模块,用于读取mat文件

class Color(Enum):  # 定义一个Color枚举类
    """Define enum color."""
    red = (0, 0, 255)  # 红色
    green = (0, 255, 0)  # 绿色
    blue = (255, 0, 0)  # 蓝色
    cyan = (255, 255, 0)  # 青色
    yellow = (0, 255, 255)  # 黄色
    magenta = (255, 0, 255)  # 洋红色
    white = (255, 255, 255)  # 白色
    black = (0, 0, 0)  # 黑色

def check_file_exist(file_name: str):  # 定义一个检查文件是否存在的函数
    """Check if a file exists."""
    if not os.path.isfile(file_name):  # 如果文件不存在,抛出文件未找到异常
        raise FileNotFoundError(f"File `{file_name}` does not exist.")

def color_val(color):  # 定义一个颜色值转换函数
    """Convert color to BGR format."""
    if isinstance(color, str):  # 如果颜色是字符串类型
        return Color[color].value  # 返回枚举对应的颜色值
    if isinstance(color, Color):  # 如果颜色是Color枚举类型
        return color.value  # 返回枚举对应的颜色值
    if isinstance(color, tuple):  # 如果颜色是元组类型
        assert len(color) == 3  # 断言元组长度为3
        for channel in color:  # 检查每个通道的值
            assert 0 <= channel <= 255  # 断言通道值在0到255之间
        return color  # 返回颜色值
    if isinstance(color, int):  # 如果颜色是整数类型
        assert 0 <= color <= 255  # 断言颜色值在0到255之间
        return color, color, color  # 返回三个相同的通道值
    if isinstance(color, np.ndarray):  # 如果颜色是NumPy数组
        assert color.ndim == 1 and color.size == 3  # 断言数组为一维且大小为3
        assert np.all((color >= 0) & (color <= 255))  # 断言所有通道值在0到255之间
        color = color.astype(np.uint8)  # 将颜色值转换为无符号8位整型
        return tuple(color)  # 返回颜色值
    raise TypeError(f'Invalid type for color: {type(color)}')  # 如果类型不匹配,抛出类型错误

def imread(image, mode=None):
    """读取图像。"""
    if isinstance(image, pathlib.Path):  # 如果image是pathlib.Path对象
        image = str(image)  # 将其转换为字符串

    if isinstance(image, np.ndarray):  # 如果image是NumPy数组
        pass  # 什么都不做
    elif isinstance(image, str):  # 如果image是字符串
        check_file_exist(image)  # 检查文件是否存在
        image = Image.open(image)  # 打开图像
        if mode:  # 如果指定了模式
            image = np.array(image.convert(mode))  # 将图像转换为指定模式,并转换为NumPy数组
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")  # 如果image既不是NumPy数组也不是字符串或Path对象,抛出类型错误

    return image  # 返回图像


def imwrite(image, image_path, auto_mkdir=True):
    """保存图像。"""
    if auto_mkdir:  # 如果自动创建目录
        dir_name = os.path.abspath(os.path.dirname(image_path))  # 获取图像路径的目录名
        if dir_name != '':  # 如果目录名不为空
            dir_name = os.path.expanduser(dir_name)  # 展开用户目录
            os.makedirs(dir_name, mode=777, exist_ok=True)  # 创建目录,权限为777,如果目录已存在则不报错

    image = Image.fromarray(image)  # 将NumPy数组转换为PIL图像
    image.save(image_path)  # 保存图像到指定路径


def imshow(img, win_name='', wait_time=0):
    """显示图像"""
    cv2.imshow(win_name, imread(img))  # 显示图像
    if wait_time == 0:  # 如果等待时间为0
        while True:
            ret = cv2.waitKey(1)  # 每1毫秒检测一次按键输入

            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # 如果用户关闭了窗口或者按下了某个键
            if closed or ret != -1:
                break  # 退出循环
    else:
        ret = cv2.waitKey(wait_time)  # 等待指定时间


def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """在图像上标注预测结果。"""
    img = imread(img, mode="RGB")  # 读取图像并转换为RGB模式
    img = img.copy()  # 复制图像
    x, y = 0, row_width  # 初始化文本位置
    text_color = color_val(text_color)  # 获取文本颜色
    for k, v in result.items():  # 遍历结果字典
        if isinstance(v, float):
            v = f'{v:.2f}'  # 将浮点数格式化为两位小数
        label_text = f'{k}: {v}'  # 构造标签文本
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)  # 在图像上绘制文本
        y += row_width  # 更新文本位置
    if out_file:  # 如果指定了输出文件
        show = False  # 不显示图像
        imwrite(img, out_file)  # 保存图像

    if show:  # 如果需要显示图像
        imshow(img, win_name, wait_time)  # 显示图像

def index2label():
    """返回ImageNet数据集的图像编号和类别的字典。"""
    # 获取meta.mat文件的路径
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    # 加载meta.mat文件,获取synsets信息
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    # 获取包含子类数量的列表
    nums_children = list(zip(*meta))[4]
    # 过滤出没有子类的synsets
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

    # 提取关键信息
    _, wnids, classes = list(zip(*meta))[:3]
    # 解析类名
    clssname = [tuple(clss.split(', ')) for clss in classes]
    # 创建从wnid到类名的映射字典
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    # 按wnid排序的类名列表
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])

    # 创建从索引到类名的映射字典
    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping


# 读取推理数据
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]  # 提取图像数据
    image = ms.Tensor(image)  # 将图像数据转换为Tensor
    prob = model.predict(image)  # 预测图像类别
    label = np.argmax(prob.asnumpy(), axis=1)  # 获取预测结果的最大概率的索引
    mapping = index2label()  # 获取索引到类名的映射
    output = {int(label): mapping[int(label)]}  # 创建输出结果字典
    print(output)  # 打印输出结果
    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
                result=output,
                out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")  # 显示并保存结果图像

         推理过程完成后,在推理文件夹下可以找到图片的推理结果,可以看出预测结果是Doberman,与期望结果相同,验证了模型的准确性。

打卡

  • 6
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值