昇思MindSpore学习笔记6-06计算机视觉--Vision Transormer图像分类

摘要:

        记录MindSpore AI框架使用ViT模型在ImageNet图像数据分类上进行训练、验证、推理的过程和方法。包括环境准备、下载数据集、数据集加载、模型解析与构建、模型训练与推理等。

一、

1. ViT模型

Vision Transformer

自注意结构模型

Self-Attention

        Transformer模型

                能够训练具有超过100B规模的参数模型

领域

        自然语言处理

        计算机视觉

不依赖卷积操作

2.模型结构

ViT模型主体结构

从下往上

最下面主输入数据集

        原图像划分为多个patch(图像块)

                二维patch(不考虑channel)转换为一维向量

中间backbone基于Transformer模型Encoder部分

        Multi-head Attention结构

        部分结构顺序有调整

                Normalization位置不同

上面Blocks堆叠后接全连接层Head

附加输入类别向量

输出识别分类结果

二、环境准备

确保安装了Python环境和MindSpore

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore

输出:

Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

三、数据准备

1.下载、解压数据集

下载源

http://image-net.org

ImageNet数据集

本案例应用数据集是从ImageNet筛选的子集。

from download import download
​
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
​
path = download(dataset_url, path, kind="zip", replace=True)

输出:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip (489.1 MB)

file_sizes: 100%|█████████████████████████████| 513M/513M [00:02<00:00, 228MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

2.数据集路径结构

.dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/

3.加载数据集

import os
​
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
​
​
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
​
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
​
trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
​
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

四、模型解析

1.Transformer基本原理

Transformer模型

基于Attention机制的编码器-解码器型结构

模型结构图:

多个Encoder和Decoder模块所组成

Encoder和Decoder详细结构图:

Encoder与Decoder结构组成

多头注意力Multi-Head Attention层

    基于自注意力Self-Attention机制

    多个Self-Attention并行组成

Feed Forward层

Normaliztion层

残差连接(Residual Connection),图中的“Add”

2.Attention模块

Self-Attention核心内容

为输入向量的每个单词学习一个权重

        给定查询向量Query

        计算Query和各个Key的相似性或者相关性

                得到注意力分布

                得到每个Key对应Value的权重系数

        对Value进行加权求和得到最终的Attention数值。

Self-Attention机制:

(1) 最初的输入向量

经过Embedding层

        映射成dim x 3

        分割成三个向量

                Q(Query)

                K(Key)

                V(Value)

输入向量为一个一维向量序列(x1,x2,x3)

每个一维向量经过Embedding层映射出Q、K、V三个向量

        只是Embedding矩阵不同

        矩阵参数通过学习得到

向量之间关联

通过Q、K、V三个矩阵可计算

其中两个向量点乘获得权重

另一个向量承载权重向加的结果

(2) 自注意力机制的自注意主要体现

Q、K、V来源于其自身

自注意过程

        提取输入的不同顺序的向量的联系与特征

        通过不同顺序向量之间的联系紧密性表现

                Q与K乘积经过Softmax的结果

获取Q,K,V向量间权重

        Q、K点乘

        除以维度的平方根

        Softmax处理所有向量的结果

(3) 全局自注意

向量V与Q、K经过Softmax结果

        weight sum

每一组Q、K、V最后都有一个V输出

当前向量结合其他向量关联权重得到结果

Self-Attention全部过程:

多头注意力机制

分割self-Attention处理的向量为多个Head部分处理

        并行加速

        保持参数总量不变

同样的query, key和value映射为高维空间(Q,K,V)

        不同子空间(Q_0,K_0,V_0)

        分开计算自注意力

        最后再合并不同子空间中的注意力信息。

同一个输入向量

多个注意力机制可以并行加速处理

处理时更充分的分析和利用了向量特征

下图中ai和aj是同一个向量分割而得

以下是Multi-Head Attention代码:

from mindspore import nn, ops
​
class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()
​
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)
​
        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)
​
    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)
​
        return out

Transformer Encoder

多结构拼接形成Transformer基础结构

Self-Attention

Feed Forward

Residual Connection

Feed Forward,Residual Connection结构代码:

from typing import Optional, Dict
​
class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
​
    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
​
        return x
​
class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell
​
    def construct(self, x):
        """ResidualCell construct."""
        return self.cell(x) + x

Self-Attention构建ViT模型中的TransformerEncoder部分:

ViT模型Transformer不同

Normalization放在Self-Attention和Feed Forward之前

其他结构不变

Transformer结构图

多个子encoder堆叠构建模型编码器

ViT模型配置超参数num_layers

        确定堆叠层数

Residual Connection,Normalization的结构

保证信息经过深层处理不退化

增强模型泛化能力

TransformerEncoder结构和多层感知器(MLP)结合

构成了ViT模型的backbone部分

class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []
​
        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)
​
            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)
​
            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)
​
    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)

ViT模型的输入

传统的Transformer结构

处理自然语言领域的词向量

(Word Embedding or Word Vector),

词向量是一维向量堆叠

图片是二维矩阵堆叠,

多头注意力机制处理一维词向量堆叠时会提取词向量之间的联系也就是上下文语义

ViT模型中:

输入图像每个channel卷积操作划分1616个patch

        一幅输入224 x 224的图像卷积处理

                得到16 x 16个patch

                每一个patch的大小就是14 x 14

每个patch矩阵拉伸成为一维向量

获得近似词向量堆叠的效果

        14 x 14patch转换为长度196的向量

图像输入网络经过的第一步处理。

Patch Embedding代码:

class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4
​
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()
​
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
​
    def construct(self, x):
        """Path Embedding construct."""
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
​
        return x

输入图像划分patch后

        经过pos_embedding

                class_embedding两个过程。

class_embedding借鉴BERT模型用于文本分类

每一个word vector之前增加一个类别值

196维向量加上class_embedding变为197维

class_embedding是一个可以学习的参数

经过网络的不断训练,输出向量的第一个维度的输出来决定最后的输出类别;

输入16 x 16patch

输出16x16个class_embedding进行分类。

pos_embedding也是一组可以学习的参数

        加入patch矩阵

pos_embedding有4种方案

        采用一维pos_embedding

        由于class_embedding是加在pos_embedding之前

        所以pos_embedding维度会比patch拉伸后的维度加1。

五、整体构建ViT

构建ViT模型代码

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter
​
​
def init(init_type, shape, dtype, name, requires_grad):
    """Init."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)
​
​
class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls') -> None:
        super(ViT, self).__init__()
​
        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches
​
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)
​
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)
​
        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)
​
    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding
​
        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)
​
        return x

整体流程图如下所示:

六、模型训练与推理

1.模型训练

模型开始训练

设定损失函数

        优化器

        回调函数

调整epoch_size

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train
​
# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
​
# construct model
network = ViT()
​
# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
​
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
​
# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)
​
# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
​
​
# define loss function
class CrossEntropySmooth(LossBase):
    """CrossEntropy."""
​
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
​
    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss
​
​
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)
​
# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
​
# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")
​
# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False,)

输出:

Downloading data from https://download-mindspore.osinfra.cn/vision/classification/vit_b_16_224.ckpt (330.2 MB)

file_sizes: 100%|████████████████████████████| 346M/346M [00:26<00:00, 13.2MB/s]
Successfully downloaded file to ./ckpt/vit_b_16_224.ckpt
epoch: 1 step: 125, loss is 1.4842896
Train epoch time: 275011.631 ms, per step time: 2200.093 ms
epoch: 2 step: 125, loss is 1.3481578
Train epoch time: 23961.255 ms, per step time: 191.690 ms
epoch: 3 step: 125, loss is 1.3990085
Train epoch time: 24217.701 ms, per step time: 193.742 ms
epoch: 4 step: 125, loss is 1.1687485
Train epoch time: 23769.989 ms, per step time: 190.160 ms
epoch: 5 step: 125, loss is 1.209775
Train epoch time: 23603.390 ms, per step time: 188.827 ms
epoch: 6 step: 125, loss is 1.3151006
Train epoch time: 23977.132 ms, per step time: 191.817 ms
epoch: 7 step: 125, loss is 1.4682239
Train epoch time: 23898.189 ms, per step time: 191.186 ms
epoch: 8 step: 125, loss is 1.2927357
Train epoch time: 23681.583 ms, per step time: 189.453 ms
epoch: 9 step: 125, loss is 1.5348746
Train epoch time: 23521.045 ms, per step time: 188.168 ms
epoch: 10 step: 125, loss is 1.3726548
Train epoch time: 23719.398 ms, per step time: 189.755 ms

2.模型验证

模型验证

ImageFolderDataset接口用于读取数据集

CrossEntropySmooth接口用于损失函数实例化

Model等接口用于编译模型

步骤:

数据增强

定义ViT网络结构

加载预训练模型参数

设置损失函数

设置评价指标

        Top_1_Accuracy输出最大值为预测结果

        Top_5_Accuracy输出前5的值为预测结果

        两个指标的值越大,代表模型准确率越高

编译模型

验证

dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
​
trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
​
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
​
# construct model
network = ViT()
​
# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
​
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)
​
# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
​
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")
​
# evaluate model
result = model.eval(dataset_val)
print(result)

输出:

{'Top_1_Accuracy': 0.7495, 'Top_5_Accuracy': 0.928}

3.模型推理

推理图片数据预处理

resize

normalize

匹配训练输入数据

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
​
trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]
​
dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

模型推理

调用模型predict方法

index2label获取对应标签

自定义show_result接口在对应图片上写结果

import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io
​
​
class Color(Enum):
    """dedine enum color."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)
​
​
def check_file_exist(file_name: str):
    """check_file_exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")
​
​
def color_val(color):
    """color_val."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')
​
​
def imread(image, mode=None):
    """imread."""
    if isinstance(image, pathlib.Path):
        image = str(image)
​
    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")
​
    return image
​
​
def imwrite(image, image_path, auto_mkdir=True):
    """imwrite."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=777, exist_ok=True)
​
    image = Image.fromarray(image)
    image.save(image_path)
​
​
def imshow(img, win_name='', wait_time=0):
    """imshow"""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # prevent from hanging if windows was closed
        while True:
            ret = cv2.waitKey(1)
​
            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)
​
​
def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)
​
    if show:
        imshow(img, win_name, wait_time)
​
​
def index2label():
    """Dictionary output for image numbers and categories of the ImageNet dataset."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']
​
    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
​
    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
​
    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping
​
​
# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    mapping = index2label()
    output = {int(label): mapping[int(label)]}
    print(output)
    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
                result=output,
                out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")

输出:

{236: 'Doberman'}

推理过程完成后

推理文件夹下找图片推理结果

  • 21
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

muren

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值