我们对上一章的代码进行更改,我们简单看一下之前的代码,首先整理路径
之后创建数据集
- 与之前略有区别,train数据集不要加repeat(),否则会一直循环下去,以至于不能完成一个epoch
再之后创建网络
下面是自定义训练中与前面代码不同的地方
目录
1 定义优化器
之后我们定义优化器,这次我们减小一点学习速率
但是效果有限,因为数据还是这么多
2 定义损失函数
之后我们定义loss
- color_label 颜色实际值
- color_predict 颜色预测值
- clothes_label 种类实际值
- clothes_predict 种类预测值
现在我们两个输出都是使用sparse_categorical_crossentropy,如果两个输出要使用不同的损失函数就可以换掉它
当然这里也可以分开定义,分开定义后面就各自使用各自的,我这里就仅仅定义了总loss
3 定义指标
之后我们定义训练指标,loss我们取均值,正确率我们使用分类问题数字编码的特有正确率SparseCategoricalAccuracy
我们在这里看一下 SparseCategoricalAccuracy(),这个函数也能自己用,我们现在让(0.02, 0.83,0.86, 0.1)为预测值,让1为真实值,我们可以看到预测值经过处理后为2(0.86最大),预测值为2,实际值为1,那么就代表预测错了,所以返回0.0
我们现在把实际值改为2
此时预测值为2,实际值为2,那就代表预测对了,所以acc的结果为1
在之前猫狗的时候需要转换的时候我们是直接用的Accuracy,两个数相同就为1,不同就为0
下面是相同的情况
4 定义训练步骤
5 定义测试步骤
6 定义指标列表
7 定义训练过程
训练过程
测试过程
训练的时候我们要把对应关系记录下来
我们把图表画出来看一下
- color_acc
- clothes_acc
- loss
我们这次减少了学习率,发现准确率都在0.9以上,这个模型还算可以
之后我们保存模型
8 预测模型
由于我们这个模型还算可以,所以我们找数据集外的图测试一下
这个预测没有问题,那么我们现在想测试一下,能不能识别黑短袖(black_shirt),这个在标签中存在,但是在数据集中并没有黑短袖的分类
我们发现是识别不出来的,而且颜色的识别结果中黑色相对于蓝色的大小是非常低的,这个测试告诉我们如果机器没有见过黑短袖,只在其他衣服中见过黑色与短袖,这样是很难识别出来的