Reflections
This course has given us one more tool for computer vision.
Before getting into it, though, a few notes from actually running it.
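Last time I mentioned that computing on the CPU seemed faster and more stable than computing on Ascend. I tried it again this time, and it did look that way, but then I checked the run output and read the code more carefully...
CPU: epoch=1
Ascend: epoch=5
So that is where the "speed" came from.
Either way, the workflow is the same, and the CPU run is fine for learning purposes.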
Check-in screenshot
ShuffleNet Image Classification
This example does not support running in static graph mode on GPU devices; all other modes are supported.
ShuffleNet Network Overview
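ShuffleNetV1 is a computationally efficient CNN proposed by Megvii. Like MobileNet and SqueezeNet, it is aimed mainly at mobile devices, so its design goal is to reach the best possible accuracy with limited computing resources. The core of the ShuffleNetV1 design is two operations, Pointwise Group Convolution and Channel Shuffle, which greatly reduce the amount of computation while preserving accuracy. Like MobileNet, ShuffleNetV1 therefore compresses and accelerates the model by designing a more efficient network structure.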
For more details on ShuffleNet, see the ShuffleNet paper.
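As shown in the figure below, ShuffleNet reduces the parameter count almost to the minimum while keeping reasonably high accuracy. As a result it runs fast, and its accuracy per parameter is very high.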
Image source: Bianco S, Cadene R, Celona L, et al. Benchmark analysis of representative deep neural network architectures[J]. IEEE Access, 2018, 6: 64270-64277.
Model Architecture
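ShuffleNet's most distinctive feature is that it rearranges channels across groups to overcome the drawbacks of Group Convolution. By modifying ResNet's Bottleneck unit, it reaches high accuracy at a small computational cost.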
Pointwise Group Convolution
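The principle of Group Convolution is shown in the figure below. In a group convolution, each group's kernel has size in_channels/g*k*k and there are g groups, so all groups together have (in_channels/g*k*k)*out_channels parameters, which is 1/g of an ordinary convolution. Each kernel processes only a subset of the input feature map's channels; this lowers the parameter count, while the number of output channels still equals the number of kernels.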
Image source: Huang G, Liu S, Van der Maaten L, et al. CondenseNet: An efficient DenseNet using learned group convolutions[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018: 2752-2761.
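As a quick check of the 1/g claim, the sketch below counts the weights of a convolution with and without grouping, using the formula (in_channels/g*k*k)*out_channels from the paragraph above. The function name `conv_params` and the 240-channel sizes are illustrative assumptions, and biases are ignored.

```python
def conv_params(in_channels: int, out_channels: int, k: int, groups: int = 1) -> int:
    """Number of weights in a k x k convolution split into `groups` groups (bias ignored)."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("in_channels and out_channels must both be divisible by groups")
    # Each of the out_channels kernels sees only in_channels / groups input channels.
    return (in_channels // groups) * k * k * out_channels


dense = conv_params(240, 240, k=3)              # ordinary convolution
grouped = conv_params(240, 240, k=3, groups=3)  # group convolution with g = 3
print(dense, grouped, dense // grouped)  # 518400 172800 3
```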
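Depthwise Convolution sets the number of groups g equal to the number of input channels, in_channels, and convolves each input channel separately, so each kernel processes exactly one channel. With a kernel size of 1*k*k, the kernel parameter count is in_channels*k*k, and the output feature map has the same number of channels as the input.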
Pointwise Group Convolution builds on group convolution by setting each group's kernel size to 1×1, giving a kernel parameter count of (in_channels/g*1*1)*out_channels.
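The same formula covers the two special cases just described: depthwise convolution is a group convolution with g = in_channels, and pointwise group convolution is one with k = 1. The sketch below compares them with a dense 1×1 convolution; the 240-channel, g = 3 sizes are assumptions chosen for illustration.

```python
def conv_params(in_channels, out_channels, k, groups=1):
    # (in_channels / g * k * k) * out_channels weights, bias ignored
    return (in_channels // groups) * k * k * out_channels


c_in, c_out, k, g = 240, 240, 3, 3
depthwise = conv_params(c_in, c_in, k, groups=c_in)      # g = in_channels: one 1*k*k kernel per channel
pointwise = conv_params(c_in, c_out, 1)                  # dense 1x1 convolution
pointwise_group = conv_params(c_in, c_out, 1, groups=g)  # 1x1 kernels inside each of the g groups

print(depthwise, pointwise, pointwise_group)  # 2160 57600 19200
```

The pointwise group convolution needs 1/g of the weights of the dense 1×1 layer, which is exactly the saving ShuffleNet relies on.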
[1]:
%%capture captured_output
# The lab environment comes with mindspore==2.3.0rc1 preinstalled; to use a different version, change the mindspore version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1
[2]:
# Check the current MindSpore version
!pip show mindspore
Name: mindspore
Version: 2.3.0rc1
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by:
[3]:
from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensor
class GroupConv(nn.Cell):
    """Group convolution built from `groups` independent nn.Conv2d layers."""
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
        super(GroupConv, self).__init__()
        self.groups = groups
        self.convs = nn.CellList()
        for _ in range(groups):
            # each group maps in_channels/g input channels to out_channels/g output channels
            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
                                        kernel_size=kernel_size, stride=stride, has_bias=has_bias,
                                        padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))

    def construct(self, x):
        # split the NCHW input into `groups` equal chunks along the channel axis
        features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
        outputs = ()
        for i in range(self.groups):
            outputs = outputs + (self.convs[i](features[i].astype("float32")),)
        # concatenate the per-group outputs back along the channel axis
        out = ops.cat(outputs, axis=1)
        return out
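To see what GroupConv computes without running MindSpore, the sketch below evaluates a 1×1 group convolution at a single pixel in plain Python. At one pixel a 1×1 convolution is a matrix-vector product, and grouping makes that matrix block-diagonal: group i only ever sees its own slice of channels, just as each `self.convs[i]` only receives `features[i]`. The function name `grouped_pointwise` and the toy weights are illustrative assumptions.

```python
def grouped_pointwise(x, weights):
    """1x1 group convolution at a single pixel.

    x: list of in_channels values for that pixel.
    weights: one weight matrix per group; each matrix has out_channels/g rows
             of in_channels/g entries.
    """
    g = len(weights)
    if len(x) % g:
        raise ValueError("len(x) must be divisible by the number of groups")
    size_in = len(x) // g
    out = []
    for i, w in enumerate(weights):
        xs = x[i * size_in:(i + 1) * size_in]  # the channel slice handed to group i
        out.extend(sum(wij * xj for wij, xj in zip(row, xs)) for row in w)
    return out  # concatenated in group order, like ops.cat(outputs, axis=1)


x = [1, 2, 3, 4]
weights = [[[1, 1]],    # group 0: one output channel computed from x[0:2]
           [[1, -1]]]   # group 1: one output channel computed from x[2:4]
print(grouped_pointwise(x, weights))  # [3, -1]
```

Changing x[2] or x[3] never affects the first output, which is the "no communication between groups" problem that Channel Shuffle addresses next.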
Channel Shuffle
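The drawback of Group Convolution is that channels in different groups cannot exchange information. Once GConv layers are stacked, the feature maps of different groups never communicate, as if the network were split into g separate roads on which everyone travels alone, and this can weaken the network's ability to extract features. It is also why networks such as Xception and MobileNet use dense 1x1 convolutions (Dense Pointwise Convolution).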
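To avoid this "inbreeding" among channels of the same group, ShuffleNet cuts down the many dense 1x1 convolutions (which, where they are used, account for a striking 93.4% of the computation) and introduces a Channel Shuffle mechanism. Intuitively, this operation evenly scatters the channels of each group and regroups them, so that the next layer can process information coming from every group.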
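As shown in the figure below, for a feature map with g groups of n channels each, the channels are first reshaped into a matrix with g rows and n columns, the matrix is then transposed to n rows and g columns, and finally it is flattened to obtain the new channel order. These operations are differentiable and cheap to compute, so they enable information exchange between groups while staying in line with ShuffleNet's lightweight design.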
For readability, the Channel Shuffle implementation is placed in the ShuffleNet module code below.
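The reshape/transpose/flatten steps can be traced on plain channel indices. The first function below follows the description above (g×n matrix, transpose, flatten). The second mirrors the index order of `ShuffleV1Block.channel_shuffle` in the module code below, which reshapes to (batch, n, g, H, W) before transposing; working through the indices shows it yields the inverse permutation of the first. Both interleave channels from different groups. The function names are illustrative, and the sketch operates on lists rather than tensors.

```python
def channel_shuffle(channels, groups):
    """Shuffle as described in the text: g x n matrix -> transpose -> flatten."""
    n = len(channels) // groups
    rows = [channels[i * n:(i + 1) * n] for i in range(groups)]   # g rows, n columns
    return [rows[i][j] for j in range(n) for i in range(groups)]  # read column by column


def channel_shuffle_block_order(channels, groups):
    """Index order of ShuffleV1Block.channel_shuffle: n x g -> transpose -> flatten."""
    n = len(channels) // groups
    rows = [channels[a * groups:(a + 1) * groups] for a in range(n)]  # n rows, g columns
    return [rows[a][b] for b in range(groups) for a in range(n)]


chans = list(range(6))                        # g = 3 groups: [0, 1], [2, 3], [4, 5]
print(channel_shuffle(chans, 3))              # [0, 2, 4, 1, 3, 5]
print(channel_shuffle_block_order(chans, 3))  # [0, 3, 1, 4, 2, 5]
```

Splitting [0, 2, 4, 1, 3, 5] back into three consecutive pairs gives [0, 2], [4, 1], [3, 5]: every new group now mixes channels from two different original groups.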
ShuffleNet Module
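As shown in the figure below, ShuffleNet modifies ResNet's Bottleneck structure from (a) into (b) and (c):
- Replace the first and last 1×1 convolutions (channel reduction and expansion) with Pointwise Group Convolution;
- To let different groups exchange information, apply Channel Shuffle after the channel-reducing convolution;
- In the downsampling unit, set the stride of the 3×3 Depthwise Convolution to 2, which halves the height and width; the shortcut therefore uses a 3×3 average pooling with stride 2, and the addition is replaced by concatenation.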
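Concatenation only works if both branches of the downsampling unit produce the same height and width. The sketch below checks this for the settings used in the code: the 3×3 depthwise convolution uses explicit padding `ksize // 2 = 1` with stride 2, and the average-pooling shortcut uses `pad_mode='same'`. It assumes that 'same' padding yields ceil(H / stride), as in TensorFlow-style padding; the helper names are hypothetical.

```python
import math


def conv_out(size, k, stride, pad):
    """Output height/width of a convolution with explicit padding (pad_mode='pad')."""
    return (size + 2 * pad - k) // stride + 1


def same_out(size, stride):
    """Output height/width assumed for pad_mode='same': ceil(size / stride)."""
    return math.ceil(size / stride)


for h in (56, 28, 14, 7):
    main = conv_out(h, k=3, stride=2, pad=1)  # 3x3 depthwise conv, pad = ksize // 2
    shortcut = same_out(h, stride=2)          # 3x3 average-pooling shortcut
    print(h, main, shortcut)
```

Since (h - 1) // 2 + 1 equals ceil(h / 2) for every positive h, the two branches always agree, including for odd sizes such as 7 -> 4.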
[4]:
class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        # with stride 2 the shortcut (inp channels) is concatenated, so the main branch only produces oup - inp
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        # 1x1 pointwise group convolution that reduces the channels to mid_channels
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        # 3x3 depthwise convolution, then a 1x1 pointwise group convolution that expands to `outputs`
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            # 3x3 average-pooling shortcut with stride 2
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            out = self.relu(left + right)      # residual addition
        elif self.stride == 2:
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)    # concatenation along the channel axis
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x
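The channel arithmetic inside ShuffleV1Block is easy to get wrong, so here it is replayed in plain Python. The sketch assumes `mid_channels = oup // 4`, which is what ShuffleNetV1 passes in below, and checks that both GroupConv layers can split their channels evenly into groups. The function name `shuffle_unit_channels` is hypothetical.

```python
def shuffle_unit_channels(inp, oup, group, first_group, stride):
    """Channel bookkeeping of one ShuffleV1Block, mirroring the arithmetic in its __init__."""
    mid = oup // 4                                  # mid_channels as passed in by ShuffleNetV1
    outputs = oup - inp if stride == 2 else oup     # channels produced by the main branch
    if stride == 1 and inp != oup:
        raise ValueError("element-wise addition needs inp == oup")
    g1 = 1 if first_group else group                # groups of the first 1x1 GroupConv
    for c_in, c_out, g in ((inp, mid, g1), (mid, outputs, group)):
        if c_in % g or c_out % g:
            raise ValueError(f"{c_in} -> {c_out} cannot be split into {g} groups")
    total = inp + outputs if stride == 2 else outputs  # concatenation vs. addition
    return mid, outputs, total


print(shuffle_unit_channels(24, 240, 3, True, 2))    # (60, 216, 240)
print(shuffle_unit_channels(240, 240, 3, False, 1))  # (60, 240, 240)
```

In the first call, the Stage2 downsampling unit of the 1.0x, g = 3 model, the main branch produces only 216 channels, and concatenating the 24-channel shortcut brings the total to 240.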
Building the ShuffleNet Network
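The ShuffleNet architecture is shown in the figure below. Take a 224×224 input image and 3 groups (g = 3) in the 1.0x configuration as an example. The image first passes through a convolution layer with 24 kernels of size 3×3 and stride 2, which outputs a 112×112 feature map with 24 channels; a max-pooling layer with stride 2 then reduces it to 56×56, with the channel count unchanged. Next come three stacks of ShuffleNet units (Stage2, Stage3, Stage4), repeated 4, 8, and 4 times respectively. Each stage starts with a downsampling unit (figure (c) above) that halves the height and width and doubles the channel count (except in Stage2, whose downsampling unit changes the channel count from 24 to 240). Finally, global average pooling produces a 1×1×960 output, which goes through a fully connected layer and softmax to give the class probabilities (in the code below, the classifier outputs logits and the softmax is applied inside the cross-entropy loss).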
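The walk-through above can be reproduced as a small shape calculator. The default channel list is the 1.0x, g = 3 entry of `stage_out_channels` in the code below, and the sketch assumes that 'same' pooling gives ceil(H / stride). The function name is hypothetical.

```python
import math


def shufflenet_v1_shapes(size=224, channels=(24, 240, 480, 960), repeats=(4, 8, 4)):
    """(layer, channels, height/width) after each part of ShuffleNetV1."""
    shapes = []
    c = channels[0]
    size = (size + 2 * 1 - 3) // 2 + 1  # 3x3 conv, stride 2, padding 1
    shapes.append(("first_conv", c, size))
    size = math.ceil(size / 2)          # 3x3 max pooling, stride 2, 'same' padding
    shapes.append(("maxpool", c, size))
    for stage, (c, r) in enumerate(zip(channels[1:], repeats), start=2):
        size = math.ceil(size / 2)      # only the first (downsampling) unit has stride 2
        shapes.append((f"stage{stage} x{r}", c, size))
    shapes.append(("global_avgpool", c, 1))
    return shapes


for name, c, s in shufflenet_v1_shapes():
    print(f"{name:15s} {s}x{s}x{c}")
```

The final stage leaves a 7×7 map, which is why the code can use `nn.AvgPool2d(7)` as its global pooling layer for 224×224 inputs.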
[5]:
class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        # number of ShuffleV1Blocks in Stage2, Stage3 and Stage4
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        else:
            # channel configurations exist only for group = 3 and group = 8
            raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        # 3x3 convolution, stride 2, padding 1 (224 -> 112)
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                # the first block of every stage downsamples (stride 2)
                stride = 2 if i == 0 else 1
                # the first pointwise convolution of Stage2's first block is not grouped
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        # 7x7 average pooling acts as global pooling for 224x224 inputs
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)  # logits; softmax is applied inside the loss function
        return x
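Before training, it can help to confirm which blocks the constructor actually builds. The sketch below replays the nested loop of `ShuffleNetV1.__init__` in plain Python and checks that every 1×1 GroupConv can divide its channels by the group count. The helper names `block_configs` and `check_divisible` are hypothetical; the channel lists are copied from `stage_out_channels`.

```python
def block_configs(stage_out_channels, stage_repeats=(4, 8, 4)):
    """Replays the loop in ShuffleNetV1.__init__ and returns one dict per ShuffleV1Block."""
    configs = []
    input_channel = stage_out_channels[1]
    for idxstage, numrepeat in enumerate(stage_repeats):
        output_channel = stage_out_channels[idxstage + 2]
        for i in range(numrepeat):
            configs.append({
                "inp": input_channel,
                "oup": output_channel,
                "stride": 2 if i == 0 else 1,
                "first_group": idxstage == 0 and i == 0,
                "mid": output_channel // 4,
            })
            input_channel = output_channel
    return configs


def check_divisible(cfg, group):
    """True if both 1x1 GroupConvs of a block can split their channels into groups."""
    outputs = cfg["oup"] - cfg["inp"] if cfg["stride"] == 2 else cfg["oup"]
    g1 = 1 if cfg["first_group"] else group
    return (cfg["inp"] % g1 == 0 and cfg["mid"] % g1 == 0
            and cfg["mid"] % group == 0 and outputs % group == 0)


cfgs = block_configs([-1, 48, 480, 960, 1920])  # the "2.0x", g = 3 setting used for training
print(len(cfgs), cfgs[0])
print(all(check_divisible(c, 3) for c in cfgs))
```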
Model Training and Evaluation
Preparing and Loading the Training Set
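ShuffleNet is pre-trained on the CIFAR-10 dataset. CIFAR-10 contains 60,000 color images of 32×32 pixels, evenly divided into 10 classes, with 50,000 images used as the training set and 10,000 as the test set. The example below downloads the dataset archive and loads the CIFAR-10 training set with the mindspore.dataset.Cifar10Dataset interface. Currently only the binary version (CIFAR-10 binary version) is supported.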
[6]:
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"
download(url, "./dataset", kind="tar.gz", replace=True)
Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz (162.2 MB)

file_sizes: 100%|████████████████████████████| 170M/170M [00:01<00:00, 99.4MB/s]
Extracting tar.gz file...
Successfully downloaded / unzipped to ./dataset
[6]:
'./dataset'
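Because of time constraints, some parameters (num_samples, batch_size, ...) have been reduced, so the trained model performs poorly. Adjust them to your needs; running this tutorial in an Ascend/GPU environment is recommended.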
For the original tutorial and its parameters, see ShuffleNet Image Classification.
[7]:
import mindspore as ms
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms
def get_dataset(train_dataset_path, batch_size, usage):
    image_trans = []
    if usage == "train":
        # data augmentation, resize to the 224x224 input size, then normalize
        image_trans = [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    elif usage == "test":
        image_trans = [
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    label_trans = transforms.TypeCast(ms.int32)
    # num_samples=2000 keeps only 2000 images to shorten the run
    dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True, num_samples=2000)
    dataset = dataset.map(image_trans, 'image')
    dataset = dataset.map(label_trans, 'label')
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset

dataset = get_dataset("./dataset/cifar-10-batches-bin", 4, "train")
batches_per_epoch = dataset.get_dataset_size()
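Two pieces of arithmetic in this pipeline are worth making explicit: what Rescale followed by Normalize does to a pixel, and how many batches an epoch contains. The sketch below reproduces both in plain Python, assuming Rescale computes `x * rescale + shift` and Normalize computes `(x - mean) / std` per channel; the helper names are hypothetical.

```python
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def preprocess_pixel(rgb):
    """Rescale(1/255, 0) followed by Normalize(MEAN, STD) for one RGB pixel."""
    return [((v / 255.0) - m) / s for v, m, s in zip(rgb, MEAN, STD)]


def steps_per_epoch(num_samples, batch_size, drop_remainder=True):
    """Number of batches produced by dataset.batch(batch_size, drop_remainder=...)."""
    full, rest = divmod(num_samples, batch_size)
    return full if drop_remainder or rest == 0 else full + 1


print(preprocess_pixel([255, 255, 255]))
print(steps_per_epoch(2000, 4), steps_per_epoch(2000, 2))  # 500 1000
```

With `num_samples=2000` and a batch size of 4, an epoch has 500 steps, which matches the 500 steps in the training log below; the evaluation call with batch size 2 yields 1000 batches.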
Model Training
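This section pre-trains the model from randomly initialized parameters. First, ShuffleNetV1 is called to define the network with model size "2.0x", and the loss function is defined as cross-entropy loss with label smoothing 0.1. The original tutorial describes a 4-epoch warmup followed by cosine annealing; the shortened code below has no warmup: it builds a 2-epoch cosine schedule with mindspore.nn.cosine_decay_lr, but passes only its last value (lr_scheduler[-1]) to the optimizer, so the learning rate stays constant during the single training epoch. The optimizer is Momentum. Finally, the Model interface from mindspore.train wraps the network, loss function, and optimizer into model, and model.train() trains the network. ModelCheckpoint (configured with CheckpointConfig), TimeMonitor, and LossMonitor are passed in as callbacks; they print the epoch, loss, and timing information and save the ckpt file in the current directory.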
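To see which learning rate the training code below actually uses, the sketch re-implements in plain Python the formula given in the MindSpore API documentation for `mindspore.nn.cosine_decay_lr`: `min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * current_epoch / decay_epoch))` with `current_epoch = floor(i / step_per_epoch)`. This is an approximation based on that documented formula, not MindSpore's own code, and `batches_per_epoch = 500` is taken from the dataset settings above.

```python
import math


def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    """Plain-Python version of the documented cosine_decay_lr formula."""
    lrs = []
    for i in range(total_step):
        current_epoch = i // step_per_epoch
        lrs.append(min_lr + 0.5 * (max_lr - min_lr)
                   * (1 + math.cos(math.pi * current_epoch / decay_epoch)))
    return lrs


batches_per_epoch = 500  # 2000 samples / batch_size 4
schedule = cosine_decay_lr(0.0005, 0.05, batches_per_epoch * 2, batches_per_epoch, decay_epoch=2)
print(len(schedule), schedule[0], schedule[-1])
```

Under this formula the schedule starts at 0.05 and its last entry is about 0.02525, so `Tensor(lr_scheduler[-1])` fixes the learning rate at roughly 0.02525 for the whole one-epoch run.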
[8]:
import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy
def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="CPU")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*2,
                                                batches_per_epoch,
                                                decay_epoch=2)
    # only the last value of the schedule is used, so the learning rate is constant for the whole run
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    # save a checkpoint once per epoch and keep at most 5 files
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]
    print("============== Starting Training ==============")
    start_time = time.time()
    # train for a single epoch
    model.train(1, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

if __name__ == '__main__':
    train()
model size is 2.0x ============== Starting Training ============== epoch: 1 step: 1, loss is 2.6258254051208496 epoch: 1 step: 2, loss is 39.794734954833984 epoch: 1 step: 3, loss is 44.542945861816406 epoch: 1 step: 4, loss is 75.65628814697266 epoch: 1 step: 5, loss is 127.48133850097656 epoch: 1 step: 6, loss is 87.5020751953125 epoch: 1 step: 7, loss is 101.62273406982422 epoch: 1 step: 8, loss is 155.16738891601562 epoch: 1 step: 9, loss is 241.32371520996094 epoch: 1 step: 10, loss is 227.01644897460938 epoch: 1 step: 11, loss is 138.65396118164062 epoch: 1 step: 12, loss is 176.27835083007812 epoch: 1 step: 13, loss is 307.84295654296875 epoch: 1 step: 14, loss is 78.77322387695312 epoch: 1 step: 15, loss is 73.79924774169922 epoch: 1 step: 16, loss is 104.46890258789062 epoch: 1 step: 17, loss is 42.950660705566406 epoch: 1 step: 18, loss is 65.44954681396484 epoch: 1 step: 19, loss is 90.81419372558594 epoch: 1 step: 20, loss is 56.11922836303711 epoch: 1 step: 21, loss is 23.536577224731445 epoch: 1 step: 22, loss is 47.05916213989258 epoch: 1 step: 23, loss is 100.97496795654297 epoch: 1 step: 24, loss is 144.5192108154297 epoch: 1 step: 25, loss is 93.51004791259766 epoch: 1 step: 26, loss is 50.41395568847656 epoch: 1 step: 27, loss is 15.169258117675781 epoch: 1 step: 28, loss is 49.37852478027344 epoch: 1 step: 29, loss is 39.40176773071289 epoch: 1 step: 30, loss is 20.11125373840332 epoch: 1 step: 31, loss is 19.932083129882812 epoch: 1 step: 32, loss is 25.057092666625977 epoch: 1 step: 33, loss is 14.10660171508789 epoch: 1 step: 34, loss is 26.354232788085938 epoch: 1 step: 35, loss is 6.359532356262207 epoch: 1 step: 36, loss is 12.980948448181152 epoch: 1 step: 37, loss is 26.174339294433594 epoch: 1 step: 38, loss is 13.695343971252441 epoch: 1 step: 39, loss is 29.925888061523438 epoch: 1 step: 40, loss is 8.584239959716797 epoch: 1 step: 41, loss is 21.931264877319336 epoch: 1 step: 42, loss is 12.06672477722168 epoch: 1 step: 43, loss is 
22.264915466308594 epoch: 1 step: 44, loss is 4.067727088928223 epoch: 1 step: 45, loss is 6.653378486633301 epoch: 1 step: 46, loss is 10.231215476989746 epoch: 1 step: 47, loss is 2.7845568656921387 epoch: 1 step: 48, loss is 5.390234470367432 epoch: 1 step: 49, loss is 3.0019350051879883 epoch: 1 step: 50, loss is 3.331113576889038 epoch: 1 step: 51, loss is 2.9404754638671875 epoch: 1 step: 52, loss is 2.819676160812378 epoch: 1 step: 53, loss is 2.6152536869049072 epoch: 1 step: 54, loss is 4.664990425109863 epoch: 1 step: 55, loss is 3.8249380588531494 epoch: 1 step: 56, loss is 2.2946901321411133 epoch: 1 step: 57, loss is 5.391695976257324 epoch: 1 step: 58, loss is 2.6904654502868652 epoch: 1 step: 59, loss is 2.123711585998535 epoch: 1 step: 60, loss is 2.2266998291015625 epoch: 1 step: 61, loss is 3.560650587081909 epoch: 1 step: 62, loss is 10.60955810546875 epoch: 1 step: 63, loss is 12.901015281677246 epoch: 1 step: 64, loss is 11.2257080078125 epoch: 1 step: 65, loss is 6.704418182373047 epoch: 1 step: 66, loss is 4.425196170806885 epoch: 1 step: 67, loss is 7.6208319664001465 epoch: 1 step: 68, loss is 3.637984275817871 epoch: 1 step: 69, loss is 7.687455654144287 epoch: 1 step: 70, loss is 5.50910758972168 epoch: 1 step: 71, loss is 6.786325454711914 epoch: 1 step: 72, loss is 3.295285940170288 epoch: 1 step: 73, loss is 6.319259166717529 epoch: 1 step: 74, loss is 4.665078639984131 epoch: 1 step: 75, loss is 3.1605069637298584 epoch: 1 step: 76, loss is 3.25880765914917 epoch: 1 step: 77, loss is 1.940742015838623 epoch: 1 step: 78, loss is 4.281459808349609 epoch: 1 step: 79, loss is 2.718125104904175 epoch: 1 step: 80, loss is 3.1851389408111572 epoch: 1 step: 81, loss is 5.275569915771484 epoch: 1 step: 82, loss is 3.299849271774292 epoch: 1 step: 83, loss is 5.063135147094727 epoch: 1 step: 84, loss is 2.1495776176452637 epoch: 1 step: 85, loss is 5.728967666625977 epoch: 1 step: 86, loss is 4.1019206047058105 epoch: 1 step: 87, loss is 
3.6750073432922363 epoch: 1 step: 88, loss is 3.359067440032959 epoch: 1 step: 89, loss is 2.8257193565368652 epoch: 1 step: 90, loss is 5.062455654144287 epoch: 1 step: 91, loss is 4.298058986663818 epoch: 1 step: 92, loss is 2.795989513397217 epoch: 1 step: 93, loss is 3.003056287765503 epoch: 1 step: 94, loss is 2.682116985321045 epoch: 1 step: 95, loss is 3.50089168548584 epoch: 1 step: 96, loss is 3.6953463554382324 epoch: 1 step: 97, loss is 3.617540121078491 epoch: 1 step: 98, loss is 3.163033962249756 epoch: 1 step: 99, loss is 2.6230764389038086 epoch: 1 step: 100, loss is 2.1377086639404297 epoch: 1 step: 101, loss is 2.177856206893921 epoch: 1 step: 102, loss is 2.5667147636413574 epoch: 1 step: 103, loss is 2.4466171264648438 epoch: 1 step: 104, loss is 4.172312259674072 epoch: 1 step: 105, loss is 2.986344814300537 epoch: 1 step: 106, loss is 3.572352886199951 epoch: 1 step: 107, loss is 3.6122045516967773 epoch: 1 step: 108, loss is 2.0893056392669678 epoch: 1 step: 109, loss is 2.4502604007720947 epoch: 1 step: 110, loss is 2.1798033714294434 epoch: 1 step: 111, loss is 2.139862537384033 epoch: 1 step: 112, loss is 2.40075945854187 epoch: 1 step: 113, loss is 2.257197618484497 epoch: 1 step: 114, loss is 2.3176167011260986 epoch: 1 step: 115, loss is 1.8200316429138184 epoch: 1 step: 116, loss is 2.790212631225586 epoch: 1 step: 117, loss is 3.8091580867767334 epoch: 1 step: 118, loss is 2.492215633392334 epoch: 1 step: 119, loss is 3.1819939613342285 epoch: 1 step: 120, loss is 2.3661324977874756 epoch: 1 step: 121, loss is 2.6036696434020996 epoch: 1 step: 122, loss is 2.20636248588562 epoch: 1 step: 123, loss is 2.164297580718994 epoch: 1 step: 124, loss is 1.981563687324524 epoch: 1 step: 125, loss is 1.9817254543304443 epoch: 1 step: 126, loss is 2.7873992919921875 epoch: 1 step: 127, loss is 2.66616153717041 epoch: 1 step: 128, loss is 3.3393306732177734 epoch: 1 step: 129, loss is 2.975348949432373 epoch: 1 step: 130, loss is 
1.8029996156692505 epoch: 1 step: 131, loss is 3.2323436737060547 epoch: 1 step: 132, loss is 3.095515489578247 epoch: 1 step: 133, loss is 1.7939496040344238 epoch: 1 step: 134, loss is 2.595576763153076 epoch: 1 step: 135, loss is 2.4362661838531494 epoch: 1 step: 136, loss is 2.508674144744873 epoch: 1 step: 137, loss is 2.927971839904785 epoch: 1 step: 138, loss is 2.006402015686035 epoch: 1 step: 139, loss is 2.460578680038452 epoch: 1 step: 140, loss is 2.9014501571655273 epoch: 1 step: 141, loss is 1.815853238105774 epoch: 1 step: 142, loss is 2.3019559383392334 epoch: 1 step: 143, loss is 2.4128715991973877 epoch: 1 step: 144, loss is 3.326815605163574 epoch: 1 step: 145, loss is 2.200629234313965 epoch: 1 step: 146, loss is 2.0387496948242188 epoch: 1 step: 147, loss is 2.707109212875366 epoch: 1 step: 148, loss is 2.6215920448303223 epoch: 1 step: 149, loss is 2.3630850315093994 epoch: 1 step: 150, loss is 2.377840042114258 epoch: 1 step: 151, loss is 2.550588607788086 epoch: 1 step: 152, loss is 2.06246018409729 epoch: 1 step: 153, loss is 2.3065648078918457 epoch: 1 step: 154, loss is 2.236825466156006 epoch: 1 step: 155, loss is 2.632615804672241 epoch: 1 step: 156, loss is 2.2588086128234863 epoch: 1 step: 157, loss is 2.6022658348083496 epoch: 1 step: 158, loss is 2.410961389541626 epoch: 1 step: 159, loss is 1.991721272468567 epoch: 1 step: 160, loss is 2.3659610748291016 epoch: 1 step: 161, loss is 1.981337547302246 epoch: 1 step: 162, loss is 2.2522175312042236 epoch: 1 step: 163, loss is 2.3218607902526855 epoch: 1 step: 164, loss is 2.272987127304077 epoch: 1 step: 165, loss is 2.4044132232666016 epoch: 1 step: 166, loss is 2.0630550384521484 epoch: 1 step: 167, loss is 2.469780921936035 epoch: 1 step: 168, loss is 2.350719928741455 epoch: 1 step: 169, loss is 2.135317325592041 epoch: 1 step: 170, loss is 2.5933918952941895 epoch: 1 step: 171, loss is 2.3896148204803467 epoch: 1 step: 172, loss is 2.3143160343170166 epoch: 1 step: 173, loss is 
1.9470417499542236 epoch: 1 step: 174, loss is 2.2532448768615723 epoch: 1 step: 175, loss is 2.1132826805114746 epoch: 1 step: 176, loss is 2.317906379699707 epoch: 1 step: 177, loss is 2.7141761779785156 epoch: 1 step: 178, loss is 2.1723790168762207 epoch: 1 step: 179, loss is 2.565363645553589 epoch: 1 step: 180, loss is 2.4471046924591064 epoch: 1 step: 181, loss is 2.376852035522461 epoch: 1 step: 182, loss is 1.9931414127349854 epoch: 1 step: 183, loss is 2.525416851043701 epoch: 1 step: 184, loss is 1.9096590280532837 epoch: 1 step: 185, loss is 2.028379201889038 epoch: 1 step: 186, loss is 1.9379422664642334 epoch: 1 step: 187, loss is 2.609530448913574 epoch: 1 step: 188, loss is 1.987502098083496 epoch: 1 step: 189, loss is 2.467780351638794 epoch: 1 step: 190, loss is 2.179129123687744 epoch: 1 step: 191, loss is 2.342135190963745 epoch: 1 step: 192, loss is 2.4540958404541016 epoch: 1 step: 193, loss is 1.7237451076507568 epoch: 1 step: 194, loss is 2.3453142642974854 epoch: 1 step: 195, loss is 2.4298312664031982 epoch: 1 step: 196, loss is 2.6432907581329346 epoch: 1 step: 197, loss is 2.034980535507202 epoch: 1 step: 198, loss is 1.9781975746154785 epoch: 1 step: 199, loss is 2.5176000595092773 epoch: 1 step: 200, loss is 2.403144598007202 epoch: 1 step: 201, loss is 3.2028563022613525 epoch: 1 step: 202, loss is 2.2373926639556885 epoch: 1 step: 203, loss is 2.0563879013061523 epoch: 1 step: 204, loss is 2.3786404132843018 epoch: 1 step: 205, loss is 2.0180726051330566 epoch: 1 step: 206, loss is 1.694482684135437 epoch: 1 step: 207, loss is 2.325190305709839 epoch: 1 step: 208, loss is 2.433765172958374 epoch: 1 step: 209, loss is 2.7312464714050293 epoch: 1 step: 210, loss is 1.947196125984192 epoch: 1 step: 211, loss is 1.9539406299591064 epoch: 1 step: 212, loss is 2.4014337062835693 epoch: 1 step: 213, loss is 2.3533763885498047 epoch: 1 step: 214, loss is 1.6875959634780884 epoch: 1 step: 215, loss is 2.9546163082122803 epoch: 1 step: 216, 
loss is 2.352206230163574 epoch: 1 step: 217, loss is 2.22822904586792 epoch: 1 step: 218, loss is 1.7404886484146118 epoch: 1 step: 219, loss is 2.425816297531128 epoch: 1 step: 220, loss is 2.098177433013916 epoch: 1 step: 221, loss is 2.1560072898864746 epoch: 1 step: 222, loss is 2.641177177429199 epoch: 1 step: 223, loss is 3.201406717300415 epoch: 1 step: 224, loss is 1.9623440504074097 epoch: 1 step: 225, loss is 2.405972719192505 epoch: 1 step: 226, loss is 1.9641368389129639 epoch: 1 step: 227, loss is 2.1503448486328125 epoch: 1 step: 228, loss is 2.583299160003662 epoch: 1 step: 229, loss is 2.2811355590820312 epoch: 1 step: 230, loss is 2.3087921142578125 epoch: 1 step: 231, loss is 3.0039772987365723 epoch: 1 step: 232, loss is 2.1656908988952637 epoch: 1 step: 233, loss is 2.2291247844696045 epoch: 1 step: 234, loss is 2.264976739883423 epoch: 1 step: 235, loss is 2.2253987789154053 epoch: 1 step: 236, loss is 2.0790657997131348 epoch: 1 step: 237, loss is 2.168771743774414 epoch: 1 step: 238, loss is 2.28842830657959 epoch: 1 step: 239, loss is 2.1632187366485596 epoch: 1 step: 240, loss is 2.036649703979492 epoch: 1 step: 241, loss is 2.153747797012329 epoch: 1 step: 242, loss is 2.152430772781372 epoch: 1 step: 243, loss is 2.5319528579711914 epoch: 1 step: 244, loss is 2.431246042251587 epoch: 1 step: 245, loss is 2.685051441192627 epoch: 1 step: 246, loss is 2.087167739868164 epoch: 1 step: 247, loss is 2.0803401470184326 epoch: 1 step: 248, loss is 2.17630672454834 epoch: 1 step: 249, loss is 2.0943384170532227 epoch: 1 step: 250, loss is 2.2290658950805664 epoch: 1 step: 251, loss is 2.9906554222106934 epoch: 1 step: 252, loss is 2.0231692790985107 epoch: 1 step: 253, loss is 2.010467767715454 epoch: 1 step: 254, loss is 2.873980760574341 epoch: 1 step: 255, loss is 2.346977949142456 epoch: 1 step: 256, loss is 2.405374050140381 epoch: 1 step: 257, loss is 2.332207202911377 epoch: 1 step: 258, loss is 2.884948492050171 epoch: 1 step: 259, loss 
is 2.048842191696167 epoch: 1 step: 260, loss is 2.1544370651245117 epoch: 1 step: 261, loss is 2.6596505641937256 epoch: 1 step: 262, loss is 2.32731032371521 epoch: 1 step: 263, loss is 2.090808391571045 epoch: 1 step: 264, loss is 2.712474822998047 epoch: 1 step: 265, loss is 2.326045274734497 epoch: 1 step: 266, loss is 2.2843425273895264 epoch: 1 step: 267, loss is 2.4332685470581055 epoch: 1 step: 268, loss is 2.0449941158294678 epoch: 1 step: 269, loss is 2.365168571472168 epoch: 1 step: 270, loss is 2.3919451236724854 epoch: 1 step: 271, loss is 1.9652740955352783 epoch: 1 step: 272, loss is 1.9539532661437988 epoch: 1 step: 273, loss is 2.4892454147338867 epoch: 1 step: 274, loss is 2.0492358207702637 epoch: 1 step: 275, loss is 2.3324999809265137 epoch: 1 step: 276, loss is 2.4011893272399902 epoch: 1 step: 277, loss is 2.118913173675537 epoch: 1 step: 278, loss is 2.049823045730591 epoch: 1 step: 279, loss is 2.1467506885528564 epoch: 1 step: 280, loss is 2.1704981327056885 epoch: 1 step: 281, loss is 2.1933822631835938 epoch: 1 step: 282, loss is 2.1281886100769043 epoch: 1 step: 283, loss is 2.3650872707366943 epoch: 1 step: 284, loss is 2.231043577194214 epoch: 1 step: 285, loss is 1.9875519275665283 epoch: 1 step: 286, loss is 2.15653133392334 epoch: 1 step: 287, loss is 2.5786375999450684 epoch: 1 step: 288, loss is 2.3581206798553467 epoch: 1 step: 289, loss is 1.7174850702285767 epoch: 1 step: 290, loss is 1.8790639638900757 epoch: 1 step: 291, loss is 2.380164861679077 epoch: 1 step: 292, loss is 2.1249706745147705 epoch: 1 step: 293, loss is 2.1836869716644287 epoch: 1 step: 294, loss is 2.7065696716308594 epoch: 1 step: 295, loss is 2.3943002223968506 epoch: 1 step: 296, loss is 2.286813497543335 epoch: 1 step: 297, loss is 1.9807648658752441 epoch: 1 step: 298, loss is 2.0836973190307617 epoch: 1 step: 299, loss is 1.9861990213394165 epoch: 1 step: 300, loss is 1.9967654943466187 epoch: 1 step: 301, loss is 1.9926495552062988 epoch: 1 step: 
302, loss is 2.073901653289795 epoch: 1 step: 303, loss is 1.9408005475997925 epoch: 1 step: 304, loss is 2.432864189147949 epoch: 1 step: 305, loss is 2.2370548248291016 epoch: 1 step: 306, loss is 2.39937424659729 epoch: 1 step: 307, loss is 2.391186237335205 epoch: 1 step: 308, loss is 2.6636369228363037 epoch: 1 step: 309, loss is 2.2630040645599365 epoch: 1 step: 310, loss is 3.0700266361236572 epoch: 1 step: 311, loss is 2.2472522258758545 epoch: 1 step: 312, loss is 2.138871192932129 epoch: 1 step: 313, loss is 2.2199056148529053 epoch: 1 step: 314, loss is 2.289818048477173 epoch: 1 step: 315, loss is 2.0501785278320312 epoch: 1 step: 316, loss is 2.20554256439209 epoch: 1 step: 317, loss is 1.8506948947906494 epoch: 1 step: 318, loss is 2.431769371032715 epoch: 1 step: 319, loss is 2.232560634613037 epoch: 1 step: 320, loss is 2.251526355743408 epoch: 1 step: 321, loss is 2.6219663619995117 epoch: 1 step: 322, loss is 2.520749092102051 epoch: 1 step: 323, loss is 2.5345115661621094 epoch: 1 step: 324, loss is 2.373814105987549 epoch: 1 step: 325, loss is 2.5523581504821777 epoch: 1 step: 326, loss is 1.8571141958236694 epoch: 1 step: 327, loss is 2.4886703491210938 epoch: 1 step: 328, loss is 2.301044225692749 epoch: 1 step: 329, loss is 2.152078628540039 epoch: 1 step: 330, loss is 2.2027933597564697 epoch: 1 step: 331, loss is 1.9776010513305664 epoch: 1 step: 332, loss is 2.523000717163086 epoch: 1 step: 333, loss is 1.7177644968032837 epoch: 1 step: 334, loss is 2.4049694538116455 epoch: 1 step: 335, loss is 2.422466516494751 epoch: 1 step: 336, loss is 2.44610333442688 epoch: 1 step: 337, loss is 2.297590494155884 epoch: 1 step: 338, loss is 2.2462000846862793 epoch: 1 step: 339, loss is 2.4756531715393066 epoch: 1 step: 340, loss is 2.180751323699951 epoch: 1 step: 341, loss is 2.196174144744873 epoch: 1 step: 342, loss is 2.0754973888397217 epoch: 1 step: 343, loss is 2.373931884765625 epoch: 1 step: 344, loss is 2.4620656967163086 epoch: 1 step: 
345, loss is 2.461794137954712 epoch: 1 step: 346, loss is 2.27624773979187 epoch: 1 step: 347, loss is 2.274203062057495 epoch: 1 step: 348, loss is 3.3738486766815186 epoch: 1 step: 349, loss is 2.6414642333984375 epoch: 1 step: 350, loss is 2.34397029876709 epoch: 1 step: 351, loss is 2.1481738090515137 epoch: 1 step: 352, loss is 2.7757787704467773 epoch: 1 step: 353, loss is 2.6019771099090576 epoch: 1 step: 354, loss is 1.9724563360214233 epoch: 1 step: 355, loss is 2.566990375518799 epoch: 1 step: 356, loss is 2.2047133445739746 epoch: 1 step: 357, loss is 2.2065536975860596 epoch: 1 step: 358, loss is 2.4668900966644287 epoch: 1 step: 359, loss is 2.427429437637329 epoch: 1 step: 360, loss is 2.1901979446411133 epoch: 1 step: 361, loss is 2.2420694828033447 epoch: 1 step: 362, loss is 2.4124679565429688 epoch: 1 step: 363, loss is 2.182917594909668 epoch: 1 step: 364, loss is 2.4764533042907715 epoch: 1 step: 365, loss is 2.389598846435547 epoch: 1 step: 366, loss is 2.3450076580047607 epoch: 1 step: 367, loss is 2.284268379211426 epoch: 1 step: 368, loss is 2.1783335208892822 epoch: 1 step: 369, loss is 2.445622205734253 epoch: 1 step: 370, loss is 2.401132106781006 epoch: 1 step: 371, loss is 2.1972923278808594 epoch: 1 step: 372, loss is 2.389873743057251 epoch: 1 step: 373, loss is 2.080435037612915 epoch: 1 step: 374, loss is 2.4331464767456055 epoch: 1 step: 375, loss is 2.0456268787384033 epoch: 1 step: 376, loss is 2.343022346496582 epoch: 1 step: 377, loss is 1.868322730064392 epoch: 1 step: 378, loss is 2.545279026031494 epoch: 1 step: 379, loss is 2.482837677001953 epoch: 1 step: 380, loss is 2.2803609371185303 epoch: 1 step: 381, loss is 2.1398262977600098 epoch: 1 step: 382, loss is 2.007966995239258 epoch: 1 step: 383, loss is 2.276697874069214 epoch: 1 step: 384, loss is 2.240429401397705 epoch: 1 step: 385, loss is 2.3643879890441895 epoch: 1 step: 386, loss is 2.018928050994873 epoch: 1 step: 387, loss is 2.137725353240967 epoch: 1 step: 
388, loss is 2.158298969268799 epoch: 1 step: 389, loss is 1.769158124923706 epoch: 1 step: 390, loss is 2.034914970397949 epoch: 1 step: 391, loss is 2.1285059452056885 epoch: 1 step: 392, loss is 2.462773561477661 epoch: 1 step: 393, loss is 2.564924478530884 epoch: 1 step: 394, loss is 2.5071792602539062 epoch: 1 step: 395, loss is 2.305281639099121 epoch: 1 step: 396, loss is 1.9280025959014893 epoch: 1 step: 397, loss is 2.2663686275482178 epoch: 1 step: 398, loss is 2.5308475494384766 epoch: 1 step: 399, loss is 2.721531391143799 epoch: 1 step: 400, loss is 2.131120204925537 epoch: 1 step: 401, loss is 2.108673572540283 epoch: 1 step: 402, loss is 2.3213555812835693 epoch: 1 step: 403, loss is 1.9560959339141846 epoch: 1 step: 404, loss is 2.2348721027374268 epoch: 1 step: 405, loss is 2.790985584259033 epoch: 1 step: 406, loss is 2.0200438499450684 epoch: 1 step: 407, loss is 2.230942964553833 epoch: 1 step: 408, loss is 2.444641351699829 epoch: 1 step: 409, loss is 2.2976064682006836 epoch: 1 step: 410, loss is 2.614006996154785 epoch: 1 step: 411, loss is 1.8422152996063232 epoch: 1 step: 412, loss is 2.1191985607147217 epoch: 1 step: 413, loss is 2.3731179237365723 epoch: 1 step: 414, loss is 2.1916840076446533 epoch: 1 step: 415, loss is 2.582117795944214 epoch: 1 step: 416, loss is 1.9467241764068604 epoch: 1 step: 417, loss is 1.8540947437286377 epoch: 1 step: 418, loss is 2.274940252304077 epoch: 1 step: 419, loss is 2.331502914428711 epoch: 1 step: 420, loss is 2.0714128017425537 epoch: 1 step: 421, loss is 2.2246718406677246 epoch: 1 step: 422, loss is 2.1393539905548096 epoch: 1 step: 423, loss is 2.4221982955932617 epoch: 1 step: 424, loss is 2.3887264728546143 epoch: 1 step: 425, loss is 2.282315254211426 epoch: 1 step: 426, loss is 2.3673717975616455 epoch: 1 step: 427, loss is 2.308889150619507 epoch: 1 step: 428, loss is 2.046236038208008 epoch: 1 step: 429, loss is 2.09428334236145 epoch: 1 step: 430, loss is 2.0872511863708496 epoch: 1 step: 
431, loss is 2.3781440258026123 epoch: 1 step: 432, loss is 2.269421339035034 epoch: 1 step: 433, loss is 2.1238834857940674 epoch: 1 step: 434, loss is 2.3587095737457275 epoch: 1 step: 435, loss is 2.7772974967956543 epoch: 1 step: 436, loss is 2.8379673957824707 epoch: 1 step: 437, loss is 2.376774549484253 epoch: 1 step: 438, loss is 2.1053237915039062 epoch: 1 step: 439, loss is 1.9341497421264648 epoch: 1 step: 440, loss is 2.109922409057617 epoch: 1 step: 441, loss is 1.9373430013656616 epoch: 1 step: 442, loss is 2.2170746326446533 epoch: 1 step: 443, loss is 2.4009859561920166 epoch: 1 step: 444, loss is 2.5638539791107178 epoch: 1 step: 445, loss is 1.985969066619873 epoch: 1 step: 446, loss is 2.9069111347198486 epoch: 1 step: 447, loss is 2.2156426906585693 epoch: 1 step: 448, loss is 1.9771026372909546 epoch: 1 step: 449, loss is 2.707566976547241 epoch: 1 step: 450, loss is 2.1725211143493652 epoch: 1 step: 451, loss is 2.094482183456421 epoch: 1 step: 452, loss is 2.2152276039123535 epoch: 1 step: 453, loss is 1.8215075731277466 epoch: 1 step: 454, loss is 2.2684712409973145 epoch: 1 step: 455, loss is 2.247671127319336 epoch: 1 step: 456, loss is 2.192174196243286 epoch: 1 step: 457, loss is 2.3436570167541504 epoch: 1 step: 458, loss is 2.286713123321533 epoch: 1 step: 459, loss is 2.0330140590667725 epoch: 1 step: 460, loss is 2.13211727142334 epoch: 1 step: 461, loss is 2.2922303676605225 epoch: 1 step: 462, loss is 2.269904851913452 epoch: 1 step: 463, loss is 2.526247024536133 epoch: 1 step: 464, loss is 2.336387872695923 epoch: 1 step: 465, loss is 2.290205955505371 epoch: 1 step: 466, loss is 1.8469833135604858 epoch: 1 step: 467, loss is 2.172717571258545 epoch: 1 step: 468, loss is 2.410285711288452 epoch: 1 step: 469, loss is 2.633931875228882 epoch: 1 step: 470, loss is 1.9941343069076538 epoch: 1 step: 471, loss is 1.747193694114685 epoch: 1 step: 472, loss is 2.424727201461792 epoch: 1 step: 473, loss is 2.2178268432617188 epoch: 1 
step: 474, loss is 1.7849880456924438 epoch: 1 step: 475, loss is 2.6825149059295654 epoch: 1 step: 476, loss is 2.22454571723938 epoch: 1 step: 477, loss is 2.181126117706299 epoch: 1 step: 478, loss is 1.809498906135559 epoch: 1 step: 479, loss is 2.2522993087768555 epoch: 1 step: 480, loss is 1.9627231359481812 epoch: 1 step: 481, loss is 2.407466173171997 epoch: 1 step: 482, loss is 2.633741617202759 epoch: 1 step: 483, loss is 1.97539484500885 epoch: 1 step: 484, loss is 2.111461639404297 epoch: 1 step: 485, loss is 2.0772695541381836 epoch: 1 step: 486, loss is 2.5016844272613525 epoch: 1 step: 487, loss is 2.6679084300994873 epoch: 1 step: 488, loss is 2.14442777633667 epoch: 1 step: 489, loss is 2.147228240966797 epoch: 1 step: 490, loss is 2.048213005065918 epoch: 1 step: 491, loss is 2.4181127548217773 epoch: 1 step: 492, loss is 2.5247011184692383 epoch: 1 step: 493, loss is 2.388942003250122 epoch: 1 step: 494, loss is 1.9163062572479248 epoch: 1 step: 495, loss is 1.9449471235275269 epoch: 1 step: 496, loss is 1.8332639932632446 epoch: 1 step: 497, loss is 2.345304012298584 epoch: 1 step: 498, loss is 2.0195631980895996 epoch: 1 step: 499, loss is 2.543567180633545 epoch: 1 step: 500, loss is 2.0039420127868652
Train epoch time: 576588.721 ms, per step time: 1153.177 ms
total time:0h 9m 36s
============== Train Success ==============
The trained model is saved in the current directory as shufflenetv1-1_500.ckpt and is used for evaluation.
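The checkpoint name follows MindSpore's default `ModelCheckpoint` naming, `<prefix>-<epoch>_<step>.ckpt`. This sketch assumes that scheme. Under it, `shufflenetv1-1_500.ckpt` means epoch 1, step 500, which matches the last line of the training log. The helper `parse_ckpt_name` is hypothetical and is not part of MindSpore.

```python
import re
from typing import NamedTuple


class CkptInfo(NamedTuple):
    prefix: str
    epoch: int
    step: int


# Hypothetical helper: assumes the "<prefix>-<epoch>_<step>.ckpt" naming scheme.
# The greedy prefix lets the prefix itself contain "-".
_CKPT_RE = re.compile(r"^(?P<prefix>.+)-(?P<epoch>\d+)_(?P<step>\d+)\.ckpt$")


def parse_ckpt_name(filename: str) -> CkptInfo:
    """Split a checkpoint file name into prefix, epoch and step."""
    match = _CKPT_RE.match(filename)
    if match is None:
        raise ValueError(f"not a <prefix>-<epoch>_<step>.ckpt name: {filename!r}")
    return CkptInfo(match["prefix"], int(match["epoch"]), int(match["step"]))


print(parse_ckpt_name("shufflenetv1-1_500.ckpt"))
# CkptInfo(prefix='shufflenetv1', epoch=1, step=500)
```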
Model Evaluation
Evaluate the model on the CIFAR-10 test set.
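Set the path of the checkpoint to evaluate and load the dataset. Then configure the Top-1 and Top-5 accuracy metrics, and finally evaluate the model with the model.eval() interface.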
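Top-1 accuracy counts a sample as correct only when the highest-scoring class is the true label. Top-5 accepts the label anywhere among the five highest scores.

Below is a minimal NumPy sketch of that definition, computed on a single array of logits. The function name `topk_accuracy` is hypothetical. MindSpore's `Top1CategoricalAccuracy` and `Top5CategoricalAccuracy` accumulate the same quantity over batches, and they may break ties differently.

```python
import numpy as np


def topk_accuracy(logits: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Fraction of samples whose true label is among the k highest logits.

    logits: (N, C) scores, labels: (N,) integer class indices.
    """
    # Indices of the k largest scores per row (order inside the top k is irrelevant)
    topk = np.argsort(-logits, axis=1)[:, :k]
    # A row is a hit if any of its top-k indices equals its label
    hits = (topk == labels[:, None]).any(axis=1)
    return float(hits.mean())


logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2],
                   [0.2, 0.3, 0.5]])
labels = np.array([1, 1, 0])
print(topk_accuracy(logits, labels, 1))  # 1 of 3 rows correct
print(topk_accuracy(logits, labels, 2))  # 2 of 3 rows correct
```

CIFAR-10 has 10 classes, so uniform random guessing scores about 0.1 Top-1 and 0.5 Top-5. These are useful reference points for reading the evaluation result.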
[9]:
import time

import mindspore
from mindspore import nn, load_checkpoint, load_param_into_net
from mindspore.train import Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy

# get_dataset() and ShuffleNetV1 are defined in earlier cells of this notebook.


def test():
    # Evaluate on CPU in PyNative mode
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="CPU")
    # CIFAR-10 test split, batch size 2
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 2, "test")
    # Build the network and load the trained weights
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    param_dict = load_checkpoint("shufflenetv1-1_500.ckpt")
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # Cross-entropy with label smoothing; report loss, Top-1 and Top-5 accuracy
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
                    'Top_5_Acc': Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    start_time = time.time()
    res = model.eval(dataset, dataset_sink_mode=False)
    # Format the elapsed time as hours / minutes / seconds
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-1_500.ckpt" \
        + "', time: " + hour + "h " + minute + "m " + second + "s"
    print(log)
    # Append the result to a log file
    filename = './eval_log.txt'
    with open(filename, 'a') as file_object:
        file_object.write(log + '\n')


if __name__ == '__main__':
    test()
model size is 2.0x
result:{'Loss': 3.9649828687906266, 'Top_1_Acc': 0.2215, 'Top_5_Acc': 0.7055}, ckpt:'./shufflenetv1-1_500.ckpt', time: 0h 3m 23s
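The hour/minute/second arithmetic in `test()` can also be written with `divmod`. The sketch below uses a hypothetical helper, `format_elapsed`, that produces the same format for non-negative durations.

It is checked against the timings in the logs:

- The 576588.721 ms training epoch formats as `0h 9m 36s`, as in the training log.
- Dividing that epoch time by the 500 steps gives the reported per-step time.

```python
def format_elapsed(seconds: float) -> str:
    """Format a non-negative duration in seconds as 'Hh Mm Ss' (fractions truncated)."""
    total = int(seconds)
    hours, rest = divmod(total, 3600)
    minutes, secs = divmod(rest, 60)
    return f"{hours}h {minutes}m {secs}s"


print(format_elapsed(576588.721 / 1000))  # 0h 9m 36s
print(format_elapsed(203.7))              # 0h 3m 23s
print(round(576588.721 / 500, 3))         # 1153.177 (ms per step over 500 steps)
```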
Model Prediction
Run the model on images from the CIFAR-10 test set and visualize the predictions.
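The deterministic preprocessing steps used for prediction are plain per-pixel arithmetic:

- `Rescale(1.0 / 255.0, 0.0)` maps `x` to `x / 255`.
- `Normalize(mean, std)` subtracts the per-channel mean and divides by the per-channel std.
- `HWC2CHW` moves the channel axis to the front.

Below is a NumPy sketch of these three steps. It skips `Resize`, which would need an image library, and the function name `preprocess` is hypothetical.

```python
import numpy as np

# Per-channel CIFAR-10 statistics on the [0, 1] scale (values from the notebook)
MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
STD = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)


def preprocess(image_hwc: np.ndarray) -> np.ndarray:
    """uint8 HWC image -> normalized float32 CHW array (Resize omitted)."""
    x = image_hwc.astype(np.float32) * (1.0 / 255.0) + 0.0  # Rescale(1/255, 0)
    x = (x - MEAN) / STD  # Normalize: broadcasts over the last (channel) axis
    return x.transpose(2, 0, 1)  # HWC2CHW


image = np.zeros((32, 32, 3), dtype=np.uint8)
out = preprocess(image)
print(out.shape)  # (3, 32, 32)
```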
[10]:
import numpy as np
import matplotlib.pyplot as plt
import mindspore as ms
import mindspore.dataset as ds
from mindspore import load_checkpoint, load_param_into_net
from mindspore.dataset import vision
from mindspore.train import Model

# ShuffleNetV1 is defined in an earlier cell of this notebook.
net = ShuffleNetV1(model_size="2.0x", n_class=10)
param_dict = load_checkpoint("shufflenetv1-1_500.ckpt")
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net)
# Two unshuffled readers over the same test split: one is preprocessed for the network,
# the other keeps the raw images for display, so both yield the same 16 samples.
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="test")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="test")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
# Deterministic preprocessing only: RandomCrop / RandomHorizontalFlip are training-time
# augmentations and would make the predicted images differ from the displayed ones.
image_trans = [
    vision.Resize((224, 224)),
    vision.Rescale(1.0 / 255.0, 0.0),
    vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    vision.HWC2CHW()
]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
# Visualize inference: the predicted label is shown above each image
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(ms.Tensor(predict_data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
for index, image in enumerate(show_images_lst):
    plt.subplot(2, 8, index + 1)
    plt.title('{}'.format(class_dict[pred[index]]))
    plt.imshow(image)
    plt.axis("off")
plt.show()
model size is 2.0x
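The prediction cell turns the network output into labels with `np.argmax` over the class axis and a lookup in `class_dict`. If a confidence score is also wanted, a softmax over the same logits provides one.

This is an illustrative addition, not something the notebook does. The helper `decode_predictions` is hypothetical, and the sketch assumes the output is an `(N, 10)` array of logits as in that cell.

```python
import numpy as np

CLASS_NAMES = ("airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck")


def decode_predictions(logits: np.ndarray, class_names=CLASS_NAMES):
    """Return (label, probability) for the top class of each row of (N, C) logits."""
    # Numerically stable softmax: subtract the row max before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    top = probs.argmax(axis=1)  # same indices as np.argmax(logits, axis=1)
    return [(class_names[i], float(probs[row, i])) for row, i in enumerate(top)]


logits = np.zeros((2, 10))
logits[0, 3] = 4.0  # row 0 favours "cat"
logits[1, 9] = 2.0  # row 1 favours "truck"
for label, p in decode_predictions(logits):
    print(f"{label}: {p:.3f}")
```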
[11]:
import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'guojun0718')
2024-07-15 10:31:09 guojun0718