MindSpore 25-Day Learning Camp, Day 5 | Applied Practice / Computer Vision / ShuffleNet Image Classification

Takeaways

This lesson adds another weapon to our computer-vision toolbox.

But first, a few notes from hands-on experience.

Last time I mentioned that computing on CPU seemed faster than on Ascend. I tried it again this time, and sure enough it looked that way — until I checked the run output and then re-read the code...

CPU: epoch = 1

Ascend: epoch = 5

So that's where the "speed" came from.

Either way, the workflow is the same, and it works fine purely for learning.

Check-in screenshot

ShuffleNet Image Classification

This example does not support running in static graph mode on GPU devices; all other modes are supported.

Introduction to ShuffleNet

ShuffleNetV1 is a computation-efficient CNN model proposed by Megvii. Like MobileNet and SqueezeNet, it mainly targets mobile devices, so the design goal is to reach the best possible accuracy with limited compute. The core of ShuffleNetV1 is two operations, pointwise group convolution and channel shuffle, which greatly reduce the model's computation while preserving accuracy. ShuffleNetV1 is therefore similar in spirit to MobileNet: both compress and accelerate the model by designing a more efficient network structure.

For more details on ShuffleNet, see the ShuffleNet paper.

As shown below, ShuffleNet keeps accuracy competitive while driving the parameter count nearly to the minimum, so it runs fast and each parameter contributes an unusually large share of the model's accuracy.

shufflenet1

Image source: Bianco S, Cadene R, Celona L, et al. Benchmark analysis of representative deep neural network architectures[J]. IEEE Access, 2018, 6: 64270-64277.

Model Architecture

ShuffleNet's signature idea is to reorder channels across groups to fix the drawback introduced by group convolution. By reworking ResNet's bottleneck unit, it reaches high accuracy at a small computation cost.

Pointwise Group Convolution

The principle of group convolution is shown in the figure below. Compared with an ordinary convolution, a grouped convolution with g groups gives each group a kernel of size in_channels/g*k*k; across all groups the layer has (in_channels/g*k*k)*out_channels parameters, 1/g of a normal convolution. In a grouped convolution each kernel processes only a subset of the input channels; the benefit is a lower parameter count, while the number of output channels still equals the number of kernels.

shufflenet2

Image source: Huang G, Liu S, Van der Maaten L, et al. CondenseNet: An efficient DenseNet using learned group convolutions[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018: 2752-2761.

Depthwise Convolution sets the number of groups g equal to in_channels and convolves each input channel separately, so each kernel processes exactly one channel. With a kernel of size 1*k*k, the layer has in_channels*k*k parameters, and the output feature maps have as many channels as the input.

Pointwise Group Convolution builds on group convolution by fixing each group's kernel size to 1×1, giving (in_channels/g*1*1)*out_channels kernel parameters.
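
To make these parameter counts concrete, here is a small sketch (the channel sizes are arbitrary illustrative choices, not from the tutorial) that evaluates each formula:

# Parameter counts for the convolution variants discussed above,
# assuming in_channels = out_channels = 240, k = 3, g = 3 (illustrative values).
in_channels, out_channels, k, g = 240, 240, 3, 3

normal = in_channels * k * k * out_channels                   # standard convolution
grouped = (in_channels // g) * k * k * out_channels           # group convolution: 1/g of normal
depthwise = in_channels * k * k                               # one k*k filter per input channel
pointwise_group = (in_channels // g) * 1 * 1 * out_channels   # 1x1 group convolution

print(normal, grouped, depthwise, pointwise_group)
# 518400 172800 2160 19200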

[1]:

%%capture captured_output
# The lab environment comes with mindspore==2.3.0rc1 preinstalled;
# change the version number below to switch MindSpore versions
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1

[2]:

# Show the current mindspore version
!pip show mindspore
Name: mindspore
Version: 2.3.0rc1
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

[3]:

 
from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensor

class GroupConv(nn.Cell):
    """Group convolution: split the input channels into `groups` chunks, run an
    independent Conv2d on each chunk, then concatenate the results."""
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
        super(GroupConv, self).__init__()
        self.groups = groups
        self.convs = nn.CellList()
        for _ in range(groups):
            # Each per-group convolution maps in_channels/groups input channels
            # to out_channels/groups output channels.
            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
                                        kernel_size=kernel_size, stride=stride, has_bias=has_bias,
                                        padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))

    def construct(self, x):
        # Split along the channel axis (axis=1 in NCHW), convolve each chunk,
        # and concatenate the per-group outputs back together.
        features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
        outputs = ()
        for i in range(self.groups):
            outputs = outputs + (self.convs[i](features[i].astype("float32")),)
        out = ops.cat(outputs, axis=1)
        return out
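
As a quick sanity check of GroupConv, the sizes below (illustrative assumptions, not part of the original tutorial) show that grouping leaves the output shape identical to an ordinary convolution:

import numpy as np
import mindspore as ms

x = ms.Tensor(np.ones((1, 24, 56, 56)), ms.float32)  # NCHW input
conv = GroupConv(in_channels=24, out_channels=240, kernel_size=1, stride=1, groups=3)
print(conv(x).shape)  # (1, 240, 56, 56): each of the 3 groups maps 8 -> 80 channels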

Channel Shuffle

The drawback of group convolution is that channels in different groups cannot exchange information: after stacking GConv layers, the feature maps of different groups never communicate, as if the network were split into g unrelated roads with everyone walking their own way, which can weaken the network's feature-extraction ability. This is also why networks such as Xception and MobileNet use dense pointwise (1x1) convolutions.

To cure this "inbreeding" among channels of the same group, ShuffleNet optimizes away the heavy use of dense 1x1 convolutions (which, where used, account for a staggering 93.4% of the computation) by introducing the Channel Shuffle mechanism. Intuitively, it evenly scatters and regroups the channels from different groups, so the next layer can process information from every group.

shufflenet3

As shown below, for g groups with n channels each, first reshape the channel dimension into a g-row, n-column matrix, then transpose it to n rows and g columns, and finally flatten it to obtain the new channel order. These operations are differentiable and cheap to compute; they solve the information-exchange problem while staying true to ShuffleNet's lightweight design.

shufflenet4

For readability, the Channel Shuffle implementation is placed inside the ShuffleNet block code below.
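
Before reading the full implementation, a tiny standalone sketch (plain NumPy, used here only for illustration) shows the reshape -> transpose -> flatten trick with 2 groups of 3 channels:

import numpy as np

g, n = 2, 3
channels = np.arange(g * n)                    # [0 1 2 | 3 4 5], one group on each side of |
shuffled = channels.reshape(g, n).T.flatten()
print(shuffled)                                # [0 3 1 4 2 5]: every group feeds every position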

The ShuffleNet Block

As shown below, ShuffleNet changes the ResNet bottleneck structure from (a) to (b) and (c):

  1. Replace the first and last 1×1 convolution modules (dimension reduction and expansion) with pointwise group convolutions;

  2. To allow information exchange across channels, apply Channel Shuffle after the dimension reduction;

  3. In the downsampling module, the 3×3 depthwise convolution uses stride 2, halving the height and width; the shortcut therefore uses a 3×3 average pooling with stride 2, and the elementwise add is replaced by concatenation.

shufflenet5

[4]:

 
class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        # In the downsampling block (stride 2) the shortcut is concatenated,
        # so the main branch only needs to produce oup - inp channels.
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        # 1x1 pointwise group convolution for dimension reduction; the very
        # first block of the network uses an ungrouped 1x1 (first_group=True).
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        # 3x3 depthwise convolution, then a 1x1 group convolution that
        # restores the channel count.
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            # Shortcut of the downsampling block: 3x3 average pooling, stride 2.
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            # Basic block: residual add, shapes unchanged.
            out = self.relu(left + right)
        elif self.stride == 2:
            # Downsampling block: concatenate shortcut and main branch.
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        # Reshape channels to (N, n, g, H, W), swap the group/channel axes,
        # then flatten back: channels from different groups are interleaved.
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x
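
To see both modes of the block, here is a small shape check (the channel sizes mirror Stage2 of the g=3 "1.0x" configuration and are purely illustrative):

import numpy as np
import mindspore as ms

x = ms.Tensor(np.ones((1, 24, 56, 56)), ms.float32)
down = ShuffleV1Block(inp=24, oup=240, group=3, first_group=True,
                      mid_channels=60, ksize=3, stride=2)
print(down(x).shape)   # (1, 240, 28, 28): spatial size halved, 24 + 216 channels concatenated

y = ms.Tensor(np.ones((1, 240, 28, 28)), ms.float32)
basic = ShuffleV1Block(inp=240, oup=240, group=3, first_group=False,
                       mid_channels=60, ksize=3, stride=1)
print(basic(y).shape)  # (1, 240, 28, 28): the residual add keeps the shape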

Building the ShuffleNet Network

The ShuffleNet network structure is shown below. Taking a 224×224 input image and 3 groups (g = 3) as an example: the input first passes through a convolution layer with 24 kernels of size 3×3 and stride 2, producing a 112×112 feature map with 24 channels; a max-pooling layer with stride 2 then yields a 56×56 feature map with the channel count unchanged. Next come three stacked stages of ShuffleNet blocks (Stage2, Stage3, Stage4), repeated 4, 8, and 4 times respectively; each stage starts with a downsampling block (figure (c) above) that halves the feature map's height and width and doubles its channel count (except the downsampling block of Stage2, which raises the channel count from 24 to 240). Finally, global average pooling produces a 1×1×960 output, which goes through a fully connected layer and softmax to obtain the class probabilities.

shufflenet6

[5]:

 
class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        # Stage2/3/4 repeat their ShuffleNet blocks 4, 8 and 4 times respectively.
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        # Per-stage output channels for each width multiplier (index 0 is unused).
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        # Stem: 3x3 convolution with stride 2, then 3x3 max pooling with stride 2.
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                # The first block of each stage downsamples (stride 2); only the
                # very first block of the network skips grouping in its 1x1 reduction.
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        # Global average pooling over the final 7x7 feature map, then the classifier.
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x
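
A quick forward pass with a dummy input (an illustrative check added here, not part of the original tutorial) confirms the shapes traced above:

import numpy as np
import mindspore as ms

net = ShuffleNetV1(model_size='2.0x', n_class=10)
dummy = ms.Tensor(np.ones((1, 3, 224, 224)), ms.float32)
print(net(dummy).shape)  # (1, 10): one logit per CIFAR-10 class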

Model Training and Evaluation

ShuffleNet is pretrained on the CIFAR-10 dataset.

Preparing and Loading the Training Set

CIFAR-10 contains 60,000 32×32 color images evenly split into 10 classes, with 50,000 images for training and 10,000 for testing. The example below uses the mindspore.dataset.Cifar10Dataset interface to download and load the CIFAR-10 training set. Only the binary version (CIFAR-10 binary version) is currently supported.

[6]:

from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"
download(url, "./dataset", kind="tar.gz", replace=True)
Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz (162.2 MB)

file_sizes: 100%|████████████████████████████| 170M/170M [00:01<00:00, 99.4MB/s]
Extracting tar.gz file...
Successfully downloaded / unzipped to ./dataset

[6]:

'./dataset'

To save time, some parameters were scaled down (num_samples, batch_size, etc.), so the trained model performs poorly. Adjust them to your needs, and preferably run this tutorial in an Ascend/GPU environment.

For the original tutorial and its parameters, see ShuffleNet Image Classification.

[7]:

 
import mindspore as ms
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms

def get_dataset(train_dataset_path, batch_size, usage):
    image_trans = []
    if usage == "train":
        # Training pipeline: random crop/flip augmentation, resize to the
        # network's 224x224 input, rescale to [0, 1], normalize, HWC -> CHW.
        image_trans = [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    elif usage == "test":
        # Evaluation pipeline: same preprocessing without augmentation.
        image_trans = [
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    label_trans = transforms.TypeCast(ms.int32)
    # num_samples=2000 keeps this CPU run short; drop it to use the full split.
    dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True, num_samples=2000)
    dataset = dataset.map(image_trans, 'image')
    dataset = dataset.map(label_trans, 'label')
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset

dataset = get_dataset("./dataset/cifar-10-batches-bin", 4, "train")
batches_per_epoch = dataset.get_dataset_size()
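
With batch_size=4 and num_samples=2000 this gives 500 steps per epoch. A peek at one batch (a sanity check added here, not in the original tutorial) confirms the pipeline's output:

batch = next(dataset.create_dict_iterator())
print(batch['image'].shape, batch['label'].shape)  # (4, 3, 224, 224) (4,)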

Model Training

This section pretrains the network from randomly initialized parameters. First call ShuffleNetV1 to define the network, choosing model size "2.0x", and define the loss function as cross-entropy; the learning rate follows cosine annealing after a 4-epoch warmup in the original tutorial (this shortened run takes a fixed value from the schedule), and the optimizer is Momentum. Finally, the Model interface from mindspore.train wraps the network, loss function, and optimizer, and model.train() trains the network. Passing ModelCheckpoint, CheckpointConfig, TimeMonitor, and LossMonitor as callbacks prints the epoch number, loss, and time, and saves ckpt files to the current directory.

[8]:

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="CPU")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    # Build a cosine-decay schedule and take its final (smallest) value as a
    # fixed learning rate for this shortened 1-epoch run.
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*2,
                                                batches_per_epoch,
                                                decay_epoch=2)
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    # Fixed loss scaling matches the optimizer's loss_scale for O3 mixed precision.
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    # Save a checkpoint once per epoch, keeping at most the 5 most recent.
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]
    print("============== Starting Training ==============")
    start_time = time.time()
    model.train(1, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

if __name__ == '__main__':
    train()
model size is  2.0x
============== Starting Training ==============
epoch: 1 step: 1, loss is 2.6258254051208496
epoch: 1 step: 2, loss is 39.794734954833984
epoch: 1 step: 3, loss is 44.542945861816406
epoch: 1 step: 4, loss is 75.65628814697266
epoch: 1 step: 5, loss is 127.48133850097656
epoch: 1 step: 6, loss is 87.5020751953125
epoch: 1 step: 7, loss is 101.62273406982422
epoch: 1 step: 8, loss is 155.16738891601562
epoch: 1 step: 9, loss is 241.32371520996094
epoch: 1 step: 10, loss is 227.01644897460938
epoch: 1 step: 11, loss is 138.65396118164062
epoch: 1 step: 12, loss is 176.27835083007812
epoch: 1 step: 13, loss is 307.84295654296875
epoch: 1 step: 14, loss is 78.77322387695312
epoch: 1 step: 15, loss is 73.79924774169922
epoch: 1 step: 16, loss is 104.46890258789062
epoch: 1 step: 17, loss is 42.950660705566406
epoch: 1 step: 18, loss is 65.44954681396484
epoch: 1 step: 19, loss is 90.81419372558594
epoch: 1 step: 20, loss is 56.11922836303711
epoch: 1 step: 21, loss is 23.536577224731445
epoch: 1 step: 22, loss is 47.05916213989258
epoch: 1 step: 23, loss is 100.97496795654297
epoch: 1 step: 24, loss is 144.5192108154297
epoch: 1 step: 25, loss is 93.51004791259766
epoch: 1 step: 26, loss is 50.41395568847656
epoch: 1 step: 27, loss is 15.169258117675781
epoch: 1 step: 28, loss is 49.37852478027344
epoch: 1 step: 29, loss is 39.40176773071289
epoch: 1 step: 30, loss is 20.11125373840332
epoch: 1 step: 31, loss is 19.932083129882812
epoch: 1 step: 32, loss is 25.057092666625977
epoch: 1 step: 33, loss is 14.10660171508789
epoch: 1 step: 34, loss is 26.354232788085938
epoch: 1 step: 35, loss is 6.359532356262207
epoch: 1 step: 36, loss is 12.980948448181152
epoch: 1 step: 37, loss is 26.174339294433594
epoch: 1 step: 38, loss is 13.695343971252441
epoch: 1 step: 39, loss is 29.925888061523438
epoch: 1 step: 40, loss is 8.584239959716797
epoch: 1 step: 41, loss is 21.931264877319336
epoch: 1 step: 42, loss is 12.06672477722168
epoch: 1 step: 43, loss is 22.264915466308594
epoch: 1 step: 44, loss is 4.067727088928223
epoch: 1 step: 45, loss is 6.653378486633301
epoch: 1 step: 46, loss is 10.231215476989746
epoch: 1 step: 47, loss is 2.7845568656921387
epoch: 1 step: 48, loss is 5.390234470367432
epoch: 1 step: 49, loss is 3.0019350051879883
epoch: 1 step: 50, loss is 3.331113576889038
epoch: 1 step: 51, loss is 2.9404754638671875
epoch: 1 step: 52, loss is 2.819676160812378
epoch: 1 step: 53, loss is 2.6152536869049072
epoch: 1 step: 54, loss is 4.664990425109863
epoch: 1 step: 55, loss is 3.8249380588531494
epoch: 1 step: 56, loss is 2.2946901321411133
epoch: 1 step: 57, loss is 5.391695976257324
epoch: 1 step: 58, loss is 2.6904654502868652
epoch: 1 step: 59, loss is 2.123711585998535
epoch: 1 step: 60, loss is 2.2266998291015625
epoch: 1 step: 61, loss is 3.560650587081909
epoch: 1 step: 62, loss is 10.60955810546875
epoch: 1 step: 63, loss is 12.901015281677246
epoch: 1 step: 64, loss is 11.2257080078125
epoch: 1 step: 65, loss is 6.704418182373047
epoch: 1 step: 66, loss is 4.425196170806885
epoch: 1 step: 67, loss is 7.6208319664001465
epoch: 1 step: 68, loss is 3.637984275817871
epoch: 1 step: 69, loss is 7.687455654144287
epoch: 1 step: 70, loss is 5.50910758972168
epoch: 1 step: 71, loss is 6.786325454711914
epoch: 1 step: 72, loss is 3.295285940170288
epoch: 1 step: 73, loss is 6.319259166717529
epoch: 1 step: 74, loss is 4.665078639984131
epoch: 1 step: 75, loss is 3.1605069637298584
epoch: 1 step: 76, loss is 3.25880765914917
epoch: 1 step: 77, loss is 1.940742015838623
epoch: 1 step: 78, loss is 4.281459808349609
epoch: 1 step: 79, loss is 2.718125104904175
epoch: 1 step: 80, loss is 3.1851389408111572
epoch: 1 step: 81, loss is 5.275569915771484
epoch: 1 step: 82, loss is 3.299849271774292
epoch: 1 step: 83, loss is 5.063135147094727
epoch: 1 step: 84, loss is 2.1495776176452637
epoch: 1 step: 85, loss is 5.728967666625977
epoch: 1 step: 86, loss is 4.1019206047058105
epoch: 1 step: 87, loss is 3.6750073432922363
epoch: 1 step: 88, loss is 3.359067440032959
epoch: 1 step: 89, loss is 2.8257193565368652
epoch: 1 step: 90, loss is 5.062455654144287
epoch: 1 step: 91, loss is 4.298058986663818
epoch: 1 step: 92, loss is 2.795989513397217
epoch: 1 step: 93, loss is 3.003056287765503
epoch: 1 step: 94, loss is 2.682116985321045
epoch: 1 step: 95, loss is 3.50089168548584
epoch: 1 step: 96, loss is 3.6953463554382324
epoch: 1 step: 97, loss is 3.617540121078491
epoch: 1 step: 98, loss is 3.163033962249756
epoch: 1 step: 99, loss is 2.6230764389038086
epoch: 1 step: 100, loss is 2.1377086639404297
epoch: 1 step: 101, loss is 2.177856206893921
epoch: 1 step: 102, loss is 2.5667147636413574
epoch: 1 step: 103, loss is 2.4466171264648438
epoch: 1 step: 104, loss is 4.172312259674072
epoch: 1 step: 105, loss is 2.986344814300537
epoch: 1 step: 106, loss is 3.572352886199951
epoch: 1 step: 107, loss is 3.6122045516967773
epoch: 1 step: 108, loss is 2.0893056392669678
epoch: 1 step: 109, loss is 2.4502604007720947
epoch: 1 step: 110, loss is 2.1798033714294434
epoch: 1 step: 111, loss is 2.139862537384033
epoch: 1 step: 112, loss is 2.40075945854187
epoch: 1 step: 113, loss is 2.257197618484497
epoch: 1 step: 114, loss is 2.3176167011260986
epoch: 1 step: 115, loss is 1.8200316429138184
epoch: 1 step: 116, loss is 2.790212631225586
epoch: 1 step: 117, loss is 3.8091580867767334
epoch: 1 step: 118, loss is 2.492215633392334
epoch: 1 step: 119, loss is 3.1819939613342285
epoch: 1 step: 120, loss is 2.3661324977874756
epoch: 1 step: 121, loss is 2.6036696434020996
epoch: 1 step: 122, loss is 2.20636248588562
epoch: 1 step: 123, loss is 2.164297580718994
epoch: 1 step: 124, loss is 1.981563687324524
epoch: 1 step: 125, loss is 1.9817254543304443
epoch: 1 step: 126, loss is 2.7873992919921875
epoch: 1 step: 127, loss is 2.66616153717041
epoch: 1 step: 128, loss is 3.3393306732177734
epoch: 1 step: 129, loss is 2.975348949432373
epoch: 1 step: 130, loss is 1.8029996156692505
epoch: 1 step: 131, loss is 3.2323436737060547
epoch: 1 step: 132, loss is 3.095515489578247
epoch: 1 step: 133, loss is 1.7939496040344238
epoch: 1 step: 134, loss is 2.595576763153076
epoch: 1 step: 135, loss is 2.4362661838531494
epoch: 1 step: 136, loss is 2.508674144744873
epoch: 1 step: 137, loss is 2.927971839904785
epoch: 1 step: 138, loss is 2.006402015686035
epoch: 1 step: 139, loss is 2.460578680038452
epoch: 1 step: 140, loss is 2.9014501571655273
epoch: 1 step: 141, loss is 1.815853238105774
epoch: 1 step: 142, loss is 2.3019559383392334
epoch: 1 step: 143, loss is 2.4128715991973877
epoch: 1 step: 144, loss is 3.326815605163574
epoch: 1 step: 145, loss is 2.200629234313965
epoch: 1 step: 146, loss is 2.0387496948242188
epoch: 1 step: 147, loss is 2.707109212875366
epoch: 1 step: 148, loss is 2.6215920448303223
epoch: 1 step: 149, loss is 2.3630850315093994
epoch: 1 step: 150, loss is 2.377840042114258
epoch: 1 step: 151, loss is 2.550588607788086
epoch: 1 step: 152, loss is 2.06246018409729
epoch: 1 step: 153, loss is 2.3065648078918457
epoch: 1 step: 154, loss is 2.236825466156006
epoch: 1 step: 155, loss is 2.632615804672241
epoch: 1 step: 156, loss is 2.2588086128234863
epoch: 1 step: 157, loss is 2.6022658348083496
epoch: 1 step: 158, loss is 2.410961389541626
epoch: 1 step: 159, loss is 1.991721272468567
epoch: 1 step: 160, loss is 2.3659610748291016
epoch: 1 step: 161, loss is 1.981337547302246
epoch: 1 step: 162, loss is 2.2522175312042236
epoch: 1 step: 163, loss is 2.3218607902526855
epoch: 1 step: 164, loss is 2.272987127304077
epoch: 1 step: 165, loss is 2.4044132232666016
epoch: 1 step: 166, loss is 2.0630550384521484
epoch: 1 step: 167, loss is 2.469780921936035
epoch: 1 step: 168, loss is 2.350719928741455
epoch: 1 step: 169, loss is 2.135317325592041
epoch: 1 step: 170, loss is 2.5933918952941895
epoch: 1 step: 171, loss is 2.3896148204803467
epoch: 1 step: 172, loss is 2.3143160343170166
epoch: 1 step: 173, loss is 1.9470417499542236
epoch: 1 step: 174, loss is 2.2532448768615723
epoch: 1 step: 175, loss is 2.1132826805114746
epoch: 1 step: 176, loss is 2.317906379699707
epoch: 1 step: 177, loss is 2.7141761779785156
epoch: 1 step: 178, loss is 2.1723790168762207
epoch: 1 step: 179, loss is 2.565363645553589
epoch: 1 step: 180, loss is 2.4471046924591064
epoch: 1 step: 181, loss is 2.376852035522461
epoch: 1 step: 182, loss is 1.9931414127349854
epoch: 1 step: 183, loss is 2.525416851043701
epoch: 1 step: 184, loss is 1.9096590280532837
epoch: 1 step: 185, loss is 2.028379201889038
epoch: 1 step: 186, loss is 1.9379422664642334
epoch: 1 step: 187, loss is 2.609530448913574
epoch: 1 step: 188, loss is 1.987502098083496
epoch: 1 step: 189, loss is 2.467780351638794
epoch: 1 step: 190, loss is 2.179129123687744
epoch: 1 step: 191, loss is 2.342135190963745
epoch: 1 step: 192, loss is 2.4540958404541016
epoch: 1 step: 193, loss is 1.7237451076507568
epoch: 1 step: 194, loss is 2.3453142642974854
epoch: 1 step: 195, loss is 2.4298312664031982
epoch: 1 step: 196, loss is 2.6432907581329346
epoch: 1 step: 197, loss is 2.034980535507202
epoch: 1 step: 198, loss is 1.9781975746154785
epoch: 1 step: 199, loss is 2.5176000595092773
epoch: 1 step: 200, loss is 2.403144598007202
epoch: 1 step: 201, loss is 3.2028563022613525
epoch: 1 step: 202, loss is 2.2373926639556885
epoch: 1 step: 203, loss is 2.0563879013061523
epoch: 1 step: 204, loss is 2.3786404132843018
epoch: 1 step: 205, loss is 2.0180726051330566
epoch: 1 step: 206, loss is 1.694482684135437
epoch: 1 step: 207, loss is 2.325190305709839
epoch: 1 step: 208, loss is 2.433765172958374
epoch: 1 step: 209, loss is 2.7312464714050293
epoch: 1 step: 210, loss is 1.947196125984192
epoch: 1 step: 211, loss is 1.9539406299591064
epoch: 1 step: 212, loss is 2.4014337062835693
epoch: 1 step: 213, loss is 2.3533763885498047
epoch: 1 step: 214, loss is 1.6875959634780884
epoch: 1 step: 215, loss is 2.9546163082122803
epoch: 1 step: 216, loss is 2.352206230163574
epoch: 1 step: 217, loss is 2.22822904586792
epoch: 1 step: 218, loss is 1.7404886484146118
epoch: 1 step: 219, loss is 2.425816297531128
epoch: 1 step: 220, loss is 2.098177433013916
epoch: 1 step: 221, loss is 2.1560072898864746
epoch: 1 step: 222, loss is 2.641177177429199
epoch: 1 step: 223, loss is 3.201406717300415
epoch: 1 step: 224, loss is 1.9623440504074097
epoch: 1 step: 225, loss is 2.405972719192505
epoch: 1 step: 226, loss is 1.9641368389129639
epoch: 1 step: 227, loss is 2.1503448486328125
epoch: 1 step: 228, loss is 2.583299160003662
epoch: 1 step: 229, loss is 2.2811355590820312
epoch: 1 step: 230, loss is 2.3087921142578125
epoch: 1 step: 231, loss is 3.0039772987365723
epoch: 1 step: 232, loss is 2.1656908988952637
epoch: 1 step: 233, loss is 2.2291247844696045
epoch: 1 step: 234, loss is 2.264976739883423
epoch: 1 step: 235, loss is 2.2253987789154053
epoch: 1 step: 236, loss is 2.0790657997131348
epoch: 1 step: 237, loss is 2.168771743774414
epoch: 1 step: 238, loss is 2.28842830657959
epoch: 1 step: 239, loss is 2.1632187366485596
epoch: 1 step: 240, loss is 2.036649703979492
epoch: 1 step: 241, loss is 2.153747797012329
epoch: 1 step: 242, loss is 2.152430772781372
epoch: 1 step: 243, loss is 2.5319528579711914
epoch: 1 step: 244, loss is 2.431246042251587
epoch: 1 step: 245, loss is 2.685051441192627
epoch: 1 step: 246, loss is 2.087167739868164
epoch: 1 step: 247, loss is 2.0803401470184326
epoch: 1 step: 248, loss is 2.17630672454834
epoch: 1 step: 249, loss is 2.0943384170532227
epoch: 1 step: 250, loss is 2.2290658950805664
epoch: 1 step: 251, loss is 2.9906554222106934
epoch: 1 step: 252, loss is 2.0231692790985107
epoch: 1 step: 253, loss is 2.010467767715454
epoch: 1 step: 254, loss is 2.873980760574341
epoch: 1 step: 255, loss is 2.346977949142456
epoch: 1 step: 256, loss is 2.405374050140381
epoch: 1 step: 257, loss is 2.332207202911377
epoch: 1 step: 258, loss is 2.884948492050171
epoch: 1 step: 259, loss is 2.048842191696167
epoch: 1 step: 260, loss is 2.1544370651245117
epoch: 1 step: 261, loss is 2.6596505641937256
epoch: 1 step: 262, loss is 2.32731032371521
epoch: 1 step: 263, loss is 2.090808391571045
epoch: 1 step: 264, loss is 2.712474822998047
epoch: 1 step: 265, loss is 2.326045274734497
epoch: 1 step: 266, loss is 2.2843425273895264
epoch: 1 step: 267, loss is 2.4332685470581055
epoch: 1 step: 268, loss is 2.0449941158294678
epoch: 1 step: 269, loss is 2.365168571472168
epoch: 1 step: 270, loss is 2.3919451236724854
epoch: 1 step: 271, loss is 1.9652740955352783
epoch: 1 step: 272, loss is 1.9539532661437988
epoch: 1 step: 273, loss is 2.4892454147338867
epoch: 1 step: 274, loss is 2.0492358207702637
epoch: 1 step: 275, loss is 2.3324999809265137
epoch: 1 step: 276, loss is 2.4011893272399902
epoch: 1 step: 277, loss is 2.118913173675537
epoch: 1 step: 278, loss is 2.049823045730591
epoch: 1 step: 279, loss is 2.1467506885528564
epoch: 1 step: 280, loss is 2.1704981327056885
epoch: 1 step: 281, loss is 2.1933822631835938
epoch: 1 step: 282, loss is 2.1281886100769043
epoch: 1 step: 283, loss is 2.3650872707366943
epoch: 1 step: 284, loss is 2.231043577194214
epoch: 1 step: 285, loss is 1.9875519275665283
epoch: 1 step: 286, loss is 2.15653133392334
epoch: 1 step: 287, loss is 2.5786375999450684
epoch: 1 step: 288, loss is 2.3581206798553467
epoch: 1 step: 289, loss is 1.7174850702285767
epoch: 1 step: 290, loss is 1.8790639638900757
epoch: 1 step: 291, loss is 2.380164861679077
epoch: 1 step: 292, loss is 2.1249706745147705
epoch: 1 step: 293, loss is 2.1836869716644287
epoch: 1 step: 294, loss is 2.7065696716308594
epoch: 1 step: 295, loss is 2.3943002223968506
epoch: 1 step: 296, loss is 2.286813497543335
epoch: 1 step: 297, loss is 1.9807648658752441
epoch: 1 step: 298, loss is 2.0836973190307617
epoch: 1 step: 299, loss is 1.9861990213394165
epoch: 1 step: 300, loss is 1.9967654943466187
epoch: 1 step: 301, loss is 1.9926495552062988
epoch: 1 step: 302, loss is 2.073901653289795
epoch: 1 step: 303, loss is 1.9408005475997925
epoch: 1 step: 304, loss is 2.432864189147949
epoch: 1 step: 305, loss is 2.2370548248291016
epoch: 1 step: 306, loss is 2.39937424659729
epoch: 1 step: 307, loss is 2.391186237335205
epoch: 1 step: 308, loss is 2.6636369228363037
epoch: 1 step: 309, loss is 2.2630040645599365
epoch: 1 step: 310, loss is 3.0700266361236572
epoch: 1 step: 311, loss is 2.2472522258758545
epoch: 1 step: 312, loss is 2.138871192932129
epoch: 1 step: 313, loss is 2.2199056148529053
epoch: 1 step: 314, loss is 2.289818048477173
epoch: 1 step: 315, loss is 2.0501785278320312
epoch: 1 step: 316, loss is 2.20554256439209
epoch: 1 step: 317, loss is 1.8506948947906494
epoch: 1 step: 318, loss is 2.431769371032715
epoch: 1 step: 319, loss is 2.232560634613037
epoch: 1 step: 320, loss is 2.251526355743408
epoch: 1 step: 321, loss is 2.6219663619995117
epoch: 1 step: 322, loss is 2.520749092102051
epoch: 1 step: 323, loss is 2.5345115661621094
epoch: 1 step: 324, loss is 2.373814105987549
epoch: 1 step: 325, loss is 2.5523581504821777
epoch: 1 step: 326, loss is 1.8571141958236694
epoch: 1 step: 327, loss is 2.4886703491210938
epoch: 1 step: 328, loss is 2.301044225692749
epoch: 1 step: 329, loss is 2.152078628540039
epoch: 1 step: 330, loss is 2.2027933597564697
epoch: 1 step: 331, loss is 1.9776010513305664
epoch: 1 step: 332, loss is 2.523000717163086
epoch: 1 step: 333, loss is 1.7177644968032837
epoch: 1 step: 334, loss is 2.4049694538116455
epoch: 1 step: 335, loss is 2.422466516494751
epoch: 1 step: 336, loss is 2.44610333442688
epoch: 1 step: 337, loss is 2.297590494155884
epoch: 1 step: 338, loss is 2.2462000846862793
epoch: 1 step: 339, loss is 2.4756531715393066
epoch: 1 step: 340, loss is 2.180751323699951
epoch: 1 step: 341, loss is 2.196174144744873
epoch: 1 step: 342, loss is 2.0754973888397217
epoch: 1 step: 343, loss is 2.373931884765625
epoch: 1 step: 344, loss is 2.4620656967163086
epoch: 1 step: 345, loss is 2.461794137954712
epoch: 1 step: 346, loss is 2.27624773979187
epoch: 1 step: 347, loss is 2.274203062057495
epoch: 1 step: 348, loss is 3.3738486766815186
epoch: 1 step: 349, loss is 2.6414642333984375
epoch: 1 step: 350, loss is 2.34397029876709
epoch: 1 step: 351, loss is 2.1481738090515137
epoch: 1 step: 352, loss is 2.7757787704467773
epoch: 1 step: 353, loss is 2.6019771099090576
epoch: 1 step: 354, loss is 1.9724563360214233
epoch: 1 step: 355, loss is 2.566990375518799
epoch: 1 step: 356, loss is 2.2047133445739746
epoch: 1 step: 357, loss is 2.2065536975860596
epoch: 1 step: 358, loss is 2.4668900966644287
epoch: 1 step: 359, loss is 2.427429437637329
epoch: 1 step: 360, loss is 2.1901979446411133
epoch: 1 step: 361, loss is 2.2420694828033447
epoch: 1 step: 362, loss is 2.4124679565429688
epoch: 1 step: 363, loss is 2.182917594909668
epoch: 1 step: 364, loss is 2.4764533042907715
epoch: 1 step: 365, loss is 2.389598846435547
epoch: 1 step: 366, loss is 2.3450076580047607
epoch: 1 step: 367, loss is 2.284268379211426
epoch: 1 step: 368, loss is 2.1783335208892822
epoch: 1 step: 369, loss is 2.445622205734253
epoch: 1 step: 370, loss is 2.401132106781006
epoch: 1 step: 371, loss is 2.1972923278808594
epoch: 1 step: 372, loss is 2.389873743057251
epoch: 1 step: 373, loss is 2.080435037612915
epoch: 1 step: 374, loss is 2.4331464767456055
epoch: 1 step: 375, loss is 2.0456268787384033
epoch: 1 step: 376, loss is 2.343022346496582
epoch: 1 step: 377, loss is 1.868322730064392
epoch: 1 step: 378, loss is 2.545279026031494
epoch: 1 step: 379, loss is 2.482837677001953
epoch: 1 step: 380, loss is 2.2803609371185303
epoch: 1 step: 381, loss is 2.1398262977600098
epoch: 1 step: 382, loss is 2.007966995239258
epoch: 1 step: 383, loss is 2.276697874069214
epoch: 1 step: 384, loss is 2.240429401397705
epoch: 1 step: 385, loss is 2.3643879890441895
epoch: 1 step: 386, loss is 2.018928050994873
epoch: 1 step: 387, loss is 2.137725353240967
epoch: 1 step: 388, loss is 2.158298969268799
epoch: 1 step: 389, loss is 1.769158124923706
epoch: 1 step: 390, loss is 2.034914970397949
epoch: 1 step: 391, loss is 2.1285059452056885
epoch: 1 step: 392, loss is 2.462773561477661
epoch: 1 step: 393, loss is 2.564924478530884
epoch: 1 step: 394, loss is 2.5071792602539062
epoch: 1 step: 395, loss is 2.305281639099121
epoch: 1 step: 396, loss is 1.9280025959014893
epoch: 1 step: 397, loss is 2.2663686275482178
epoch: 1 step: 398, loss is 2.5308475494384766
epoch: 1 step: 399, loss is 2.721531391143799
epoch: 1 step: 400, loss is 2.131120204925537
epoch: 1 step: 401, loss is 2.108673572540283
epoch: 1 step: 402, loss is 2.3213555812835693
epoch: 1 step: 403, loss is 1.9560959339141846
epoch: 1 step: 404, loss is 2.2348721027374268
epoch: 1 step: 405, loss is 2.790985584259033
epoch: 1 step: 406, loss is 2.0200438499450684
epoch: 1 step: 407, loss is 2.230942964553833
epoch: 1 step: 408, loss is 2.444641351699829
epoch: 1 step: 409, loss is 2.2976064682006836
epoch: 1 step: 410, loss is 2.614006996154785
epoch: 1 step: 411, loss is 1.8422152996063232
epoch: 1 step: 412, loss is 2.1191985607147217
epoch: 1 step: 413, loss is 2.3731179237365723
epoch: 1 step: 414, loss is 2.1916840076446533
epoch: 1 step: 415, loss is 2.582117795944214
epoch: 1 step: 416, loss is 1.9467241764068604
epoch: 1 step: 417, loss is 1.8540947437286377
epoch: 1 step: 418, loss is 2.274940252304077
epoch: 1 step: 419, loss is 2.331502914428711
epoch: 1 step: 420, loss is 2.0714128017425537
epoch: 1 step: 421, loss is 2.2246718406677246
epoch: 1 step: 422, loss is 2.1393539905548096
epoch: 1 step: 423, loss is 2.4221982955932617
epoch: 1 step: 424, loss is 2.3887264728546143
epoch: 1 step: 425, loss is 2.282315254211426
epoch: 1 step: 426, loss is 2.3673717975616455
epoch: 1 step: 427, loss is 2.308889150619507
epoch: 1 step: 428, loss is 2.046236038208008
epoch: 1 step: 429, loss is 2.09428334236145
epoch: 1 step: 430, loss is 2.0872511863708496
epoch: 1 step: 431, loss is 2.3781440258026123
epoch: 1 step: 432, loss is 2.269421339035034
epoch: 1 step: 433, loss is 2.1238834857940674
epoch: 1 step: 434, loss is 2.3587095737457275
epoch: 1 step: 435, loss is 2.7772974967956543
epoch: 1 step: 436, loss is 2.8379673957824707
epoch: 1 step: 437, loss is 2.376774549484253
epoch: 1 step: 438, loss is 2.1053237915039062
epoch: 1 step: 439, loss is 1.9341497421264648
epoch: 1 step: 440, loss is 2.109922409057617
epoch: 1 step: 441, loss is 1.9373430013656616
epoch: 1 step: 442, loss is 2.2170746326446533
epoch: 1 step: 443, loss is 2.4009859561920166
epoch: 1 step: 444, loss is 2.5638539791107178
epoch: 1 step: 445, loss is 1.985969066619873
epoch: 1 step: 446, loss is 2.9069111347198486
epoch: 1 step: 447, loss is 2.2156426906585693
epoch: 1 step: 448, loss is 1.9771026372909546
epoch: 1 step: 449, loss is 2.707566976547241
epoch: 1 step: 450, loss is 2.1725211143493652
epoch: 1 step: 451, loss is 2.094482183456421
epoch: 1 step: 452, loss is 2.2152276039123535
epoch: 1 step: 453, loss is 1.8215075731277466
epoch: 1 step: 454, loss is 2.2684712409973145
epoch: 1 step: 455, loss is 2.247671127319336
epoch: 1 step: 456, loss is 2.192174196243286
epoch: 1 step: 457, loss is 2.3436570167541504
epoch: 1 step: 458, loss is 2.286713123321533
epoch: 1 step: 459, loss is 2.0330140590667725
epoch: 1 step: 460, loss is 2.13211727142334
epoch: 1 step: 461, loss is 2.2922303676605225
epoch: 1 step: 462, loss is 2.269904851913452
epoch: 1 step: 463, loss is 2.526247024536133
epoch: 1 step: 464, loss is 2.336387872695923
epoch: 1 step: 465, loss is 2.290205955505371
epoch: 1 step: 466, loss is 1.8469833135604858
epoch: 1 step: 467, loss is 2.172717571258545
epoch: 1 step: 468, loss is 2.410285711288452
epoch: 1 step: 469, loss is 2.633931875228882
epoch: 1 step: 470, loss is 1.9941343069076538
epoch: 1 step: 471, loss is 1.747193694114685
epoch: 1 step: 472, loss is 2.424727201461792
epoch: 1 step: 473, loss is 2.2178268432617188
epoch: 1 step: 474, loss is 1.7849880456924438
epoch: 1 step: 475, loss is 2.6825149059295654
epoch: 1 step: 476, loss is 2.22454571723938
epoch: 1 step: 477, loss is 2.181126117706299
epoch: 1 step: 478, loss is 1.809498906135559
epoch: 1 step: 479, loss is 2.2522993087768555
epoch: 1 step: 480, loss is 1.9627231359481812
epoch: 1 step: 481, loss is 2.407466173171997
epoch: 1 step: 482, loss is 2.633741617202759
epoch: 1 step: 483, loss is 1.97539484500885
epoch: 1 step: 484, loss is 2.111461639404297
epoch: 1 step: 485, loss is 2.0772695541381836
epoch: 1 step: 486, loss is 2.5016844272613525
epoch: 1 step: 487, loss is 2.6679084300994873
epoch: 1 step: 488, loss is 2.14442777633667
epoch: 1 step: 489, loss is 2.147228240966797
epoch: 1 step: 490, loss is 2.048213005065918
epoch: 1 step: 491, loss is 2.4181127548217773
epoch: 1 step: 492, loss is 2.5247011184692383
epoch: 1 step: 493, loss is 2.388942003250122
epoch: 1 step: 494, loss is 1.9163062572479248
epoch: 1 step: 495, loss is 1.9449471235275269
epoch: 1 step: 496, loss is 1.8332639932632446
epoch: 1 step: 497, loss is 2.345304012298584
epoch: 1 step: 498, loss is 2.0195631980895996
epoch: 1 step: 499, loss is 2.543567180633545
epoch: 1 step: 500, loss is 2.0039420127868652
Train epoch time: 576588.721 ms, per step time: 1153.177 ms
total time:0h 9m 36s
============== Train Success ==============

The trained model is saved as shufflenetv1-1_500.ckpt in the current directory and is used for evaluation.

Model Evaluation

The model is evaluated on the CIFAR-10 test set.

After setting the path of the checkpoint to evaluate, load the dataset, configure Top-1 and Top-5 accuracy as the evaluation metrics, and finally evaluate the model with the model.eval() interface.

[9]:

from mindspore import load_checkpoint, load_param_into_net
def test():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="CPU")
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 2, "test")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    param_dict = load_checkpoint("shufflenetv1-1_500.ckpt")
    load_param_into_net(net, param_dict)
    net.set_train(False)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
                    'Top_5_Acc': Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    start_time = time.time()
    res = model.eval(dataset, dataset_sink_mode=False)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-1_500.ckpt" \
        + "', time: " + hour + "h " + minute + "m " + second + "s"
    print(log)
    filename = './eval_log.txt'
    with open(filename, 'a') as file_object:
        file_object.write(log + '\n')
if __name__ == '__main__':
    test()
model size is  2.0x
result:{'Loss': 3.9649828687906266, 'Top_1_Acc': 0.2215, 'Top_5_Acc': 0.7055}, ckpt:'./shufflenetv1-1_500.ckpt', time: 0h 3m 23s

Model Prediction

Run prediction on CIFAR-10 images and visualize the results.

[10]:

 
import mindspore
import matplotlib.pyplot as plt
import mindspore.dataset as ds
net = ShuffleNetV1(model_size="2.0x", n_class=10)
show_lst = []
param_dict = load_checkpoint("shufflenetv1-1_500.ckpt")
load_param_into_net(net, param_dict)
model = Model(net)
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
image_trans = [
    vision.RandomCrop((32, 32), (4, 4, 4, 4)),
    vision.RandomHorizontalFlip(prob=0.5),
    vision.Resize((224, 224)),
    vision.Rescale(1.0 / 255.0, 0.0),
    vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    vision.HWC2CHW()
        ]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
# Show inference results (the title above each image is the predicted class)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(ms.Tensor(predict_data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
index = 0
for image in show_images_lst:
    plt.subplot(2, 8, index+1)
    plt.title('{}'.format(class_dict[pred[index]]))
    index += 1
    plt.imshow(image)
    plt.axis("off")
plt.show()
model size is  2.0x

[11]:

 
import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'guojun0718')
2024-07-15 10:31:09 guojun0718