之前我们所有项目的输出都是单输出,比如我们的猫狗分类,输出的仅仅为图像中的动物是猫还是狗,我们下面来搞多输出模型,举个例子,多输出的意思是在模型判定图像是猫还是狗的同时,它也会输出比如动物的毛色等其他相关信息
我们先看一下我们的数据集,数据集名称为muti_outputs_dataset
文件夹分别有七种物品,分别是 黑牛仔裤,黑鞋子,蓝裙子,蓝牛仔裤,蓝衬衫,红裙子,红短袖,数据集中共有2525张图片
在这次的训练中,我们不仅要识别出物体的种类,也要识别出物体的颜色
目录
1 导入库
2 处理路径
2.1 获取所有图片路径
首先我们获取全部图片的路径
之后我们获取照片的数量
再之后我们要对图像路径做乱序,我们这里是一定要做乱序的,因为我们后面会用到take与skip提取数据集,不做乱序就会导致测试集都是红鞋子而训练集都是黑鞋子这种情况(它看都没看过红色,然后拿出红色让它分辨是什么颜色,这样它预测的结果就非常不准),当然如果在take与skip前就做数据集的shuffle,在这里就没有必要乱序了
random.shuffle()的返回值是None,所以我们没有必要赋值
2.2 获取标签种类
之后我们获取所有标签
我们显示出来看一下
我们要同时预测颜色和种类,所以我们要分开label_names
添加进列表之后,将列表转换为set格式
- set格式的基本用法我在这个文章中有提到 python set 基本用法_potato123232的博客-CSDN博客
我们显示出来看一下
发现有三种颜色,四种物品,下面我们将颜色与物品进行编号
我们这个编码每一次运行出来的结果有可能会不一样(上面使用set的原因),我们要记录当前的值,以便我们预测时好预测
2.3 获取所有图片的对应标签
之后再把颜色标签与物品类型标签提取出来,并转换为编号
3 定义加载图片函数
我们这里的尺寸不再使用之前的256,因为我们后面要使用到mobilenet预训练模型,mobilenet的标准输出为224*224,我们按照它原来的来会更好一点
4 创建数据集
4.1 创建图片数据集
先创建路径数据集,然后创建图片数据集
4.2 创建标签数据集
我们的标签数据集包含颜色和类型两个部分,也就是说一个图片对应两个标签
4.3 将图片与标签结合起来
我们将这个数据集显示出来看一下
5 处理数据集
首先我们定义训练集与测试集,我们定义前20%为测试集,后80%为训练集
这里我们定义循环与乱序和批次,之后定义预读取
- 我们后面使用model.fit进行训练,在这里repeat()实际上没有必要
之后给测试集加上batch
6 创建模型
我们的模型结构是这样的
首先我们创建预训练网络 mobilenet,mobilenet是非常小的网络,所以非常适合部署在移动端
这里我们不使用其权重,只使用其网络结构
多分类问题中我们使用函数式API构建网络,函数式API之前在这篇文章中有提到 11.tf.keras函数式API_potato123232的博客-CSDN博客
我们这样进行搭建
上面的name我们后面编译的时候会用到
用这种方式我们可以使用get_shape()查看每一层的数据形状,比如我们现在看mobilenet的结果
我们看一下整体网络的情况
- 由于修改name的原因最后两层的名字不为dense_2和dense_3
7 编译模型
由于我们在这里两个输出都是多分类问题,我们可以使用同样的损失函数,所以我们可以这样写
如果我们的项目需要使用不同的损失函数,这个时候我们就用到了之前设置的name
8 训练模型
由于给了name所以训练时的指标变多了
我们看一下曲线
- color_acc
- clothes_acc
- color_loss
- clothes_loss
- 总loss
通过上面这四个曲线我们可以看到此时模型存在严重的过拟合现象,那么模型的结构是公认正确的,我们需要通过调小学习率与增大数据量以达到消除过拟合的情况,我在这里就不再训练了
之后我们保存模型
9 预测模型
首先导入库
我们在这里把打印出来的结果复制一下,之后进行键与值的换位
我们显示出来看一下
之后读取模型
再之后定义加载图像函数
之后我们定义预测函数
result是是一个列表,里面有两个array,分别是颜色和种类的预测结果,我们一会儿看一下就知道了,之前提到过将图像直接传入model中,就是model(img),这个和model.predict(img)的返回类型不同,但是返回结果的值是相同的
再之后我们找两张图,我们现在这个模型预测非数据集内的图片是没有意义的(val_acc太低,我们无法验证预测的过程是否正确),所以我们用训练数据来预测一下