使用Keras和tensorflow2.2可以无缝地为深度神经网络训练添加复杂的指标
Keras对基于DNN的机器学习进行了大量简化,并不断改进。这里,我们将展示如何基于混淆矩阵(召回、精度和f1)实现度量,并展示如何在tensorflow 2.2中非常简单地使用它们。
当考虑一个多类问题时,人们常说,如果类是不平衡的,那么准确性就不是一个好的度量标准。虽然这是肯定的,但是当所有的类新练的不完全拟合时,即使数据集是平衡的,准确性也是一个糟糕的度量标准。
在本文中,我将使用Fashion MNIST来进行说明。然而,这并不是本文的唯一目标,因为这可以通过在训练结束时简单地在验证集上绘制混淆矩阵来实现。我们在这里讨论的是轻松扩展keras.metrics的能力。用来在训练期间跟踪混淆矩阵的度量,可以用来跟踪类的特定召回、精度和f1,并使用keras按照通常的方式绘制它们。
在训练中获得班级特定的召回、精度和f1至少对两件事有用:
我们可以看到训练是否稳定,每个类的损失在图表中显示的时候没有跳跃太多
我们可以使用一些技巧-早期停止甚至动态改变类权值。
自tensorflow 2.2以来,添加了新的模型方法trainstep和teststep,将这些定制度量集成到训练和验证中变得非常容易。还有一个关联predict_step,我们在这里没有使用它,但它的工作原理是一样的。
我们首先创建一个自定义度量类。虽然还有更多的步骤,它们在参考的jupyter笔记本中有所体现,但重要的是实现API并与Keras 训练和测试工作流程的其余部分集成在一起。这就像实现和updatestate一样简单,updatestate接受真实的标签和预测,reset_state