简单介绍下图片分类模型的日常用处:例如给定数字图像,判断图像所属的类别:猫、狗、飞机、汽车等等。如果用函数来表示这个过程如下:
def classify(image):
label = model(image)
return label
-
首先需要准备好本次演示使用的CIFAR-10数据集,可从MindSpore官网教程的实现图片分类页面中下载,数据集简介如下。
-
进入本次正题,处理数据集,下面是具体的操作说明和代码实现
先将数据预加载出来和预处理
加载数据集
数据加载可以通过内置数据集格式Cifar10Dataset接口完成。
cifar_ds = ds.Cifar10Dataset(data_home)
数据增强
数据增强主要是对数据进行归一化和丰富数据样本数量,调用map方法在图片上执行增强操作:
resize_height = 224
resize_width = 224
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = C.RandomHorizontalFlip()
resize_op &#