1、评价标准(Evaluation Metric)
使用深度学习框架训练模型的时候都需要用到评价标准,比如准确率等,那么在MXNet框架下,这些评价标准(Evaluation Metric)是怎么实现的呢?如果我们要自定义一个不一样的评价标准要怎么做?一起来了解下吧,首先来看看在MXNet框架下关于evaluation metric的最基本的类和脚本。
1.1 、用mxnet.metric.create(metric, *args, **kwargs)创建自己的评估标准
这种方法不太常用
1.2、通过继承mx.metric.EvalMetric()类添加自己的损失函数和评估验证函数
mxnet.metric.EvalMetric
是MXNet框架中计算评价标准(evaluation metric)的基础类,这个基础类是在MXNet项目的/mxnet/metric.py
中定义的。metric.py脚本中不仅包含计算评价标准的基础类,还包含MXNet所有和metric相关的类的定义,所以如果你想更深入了解的话,可以看看这个脚本。
2、自定义一个Evaluation Metric
先来介绍下如何自定义一个evaluation metric。作者在文档中说过,这个类没有办法直接使用,当你要用的时候,应该要先自定义一个继承该基础类的类,将你想实现的评价标准写在里面,然后再在你的训练代码中调用该类。先看一个简单的定义一个evaluation metric的例子:
import mxnet as mx
import numpy as np
class Accuracy(mx.metric.EvalMetric): # 在定义类名称的时候,括号里面表示继承哪个类
def __init__(self, num=None):
super(Accuracy, self).__init__('accuracy', num)
def update(self, labels, preds):
# 因为predicts有三个维度[mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])],则preds[0]=mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
# mx.nd.argmax_channel() 返回最大值所在的维度
pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
# print(pred_label) # [1 1 1]
label = labels[0].asnumpy().astype('int32') # labels=[mx.nd.array([0, 1, 1])],labels[0]=mx.nd.array([0, 1, 1])
# print(label) # [0 1 1]
mx.metric.check_label_shapes(label, pred_label) # 检查pred_label与labe的shape是否完全一样
# print(np.array(label).shape) # (3,)
# print(np.array(pred_label).shape) # (3,)
self.sum_metric += (pred_label.flat == label.flat).sum()
# print(self.sum_metric) # 2.0
self.num_inst += len(pred_label.flat)
# print(self.num_inst) # 3
在这个自定义类中主要包含一个__init__ ()函数
和一个update()函数
,前者是用来初始化,后者是用来更新metric,比如说你定义每计算一个batch的样本,就更新一次metric。updata函数的输入中,labels和preds都是一个NDArray的列表(形如[[0. 1. 1.] <NDArray 3 @cpu(0)>])。在update函数中有个mx.metric.check_label_shapes()函数
,这个函数的定义也是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py脚本中,是用来判断labels和preds的shape是否完全一致,因为labels和preds都是list,而且一般这个list中只包含一个NDArray,比如说你的batch size是16,类别数是1000,那么labels中的NDArray就是16×1,preds中的NDArray就是16×1000。
以上这些代码可以写在类似名字为my_metric.py
脚本中,那么怎么调用了?在你的训练代码中,可以通过mxnet.metric.CompositeEvalMetric类
来调用,这个类是用来管理你的evaluation metrics的,这个类的定义也是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py
脚本中。
2.1、调用自定义的Accuracy类:
调用上方的Accuracy()类
import mxnet as mx
import numpy as np
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
# print(predicts)
# [
# [[0.3 0.7]
# [0. 1. ]
# [0.4 0.6]]
# <NDArray 3x2 @cpu(0)>]
labels = [mx.nd.array([0, 1, 1])]
# print(labels)
# [
# [0. 1. 1.]
# <NDArray 3 @cpu(0)>]
eval_metrics_1 = Accuracy()
eval_metrics = mx.metric.CompositeEvalMetric()
eval_metrics.add(eval_metrics_1)
eval_metrics.update(labels=labels, preds=predicts)
print(eval_metrics.get())
# (['accuracy'], [0.6666666666666666])
当你要调用自定义的metric时,主要用到这个CompositeEvalMetric类的add()方法,用一次add方法就能够增加一种评价方式,从而在训练界面显示,而最后的update方法是手动更新计算结果,一般我们仅需下面两行定义一个eval_metric即可,然后把eval_metric作为训练的fit函数的一个参数即可:
eval_metric = mx.metric.CompositeEvalMetric()
eval_metric.add(Accuracy())
在代码运行到要更新eval_metric的时候,更具体的是base_model.py
中的self.metric(eval_metric, data_batch.label)
这一行,即:
然后,会先调用mxnet.metric.CompositeEvalMetric()类的update_dict()函数
,再调用基类mxnet.metric.EvalMetric的update_dict()函数
2.2、调用官方的Accuracy类
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels = [mx.nd.array([0, 1, 1])]
eval_metrics_1 = mx.metric.Accuracy() # 不同点
eval_metrics_2 = mx.metric.F1()
eval_metrics = mx.metric.CompositeEvalMetric()
for child_metric in [eval_metrics_1, eval_metrics_2]:
eval_metrics.add(child_metric)
eval_metrics.update(labels = labels, preds = predicts)
print(eval_metrics.get())
# (['accuracy', 'f1'], [0.6666666666666666, 0.8])
这里的mx.metric.Accuracy()
和mx.metric.F1()
就是官方在/mxnet/metric.py
中定义的mx.metric.EvalMetric类
的实现
3、官方提供常见的评价标准类
上面说完了怎么自定义一个evaluation metric类以及如何调用自定义的metric类,那么下面就看看官方的/mxnet/metric.py脚本中默认的一些metric类(用于评价度量的类,如上图)的写法是什么样的吧。非常喜欢MXNet代码的一个原因就是其代码注释和文档做得很好,看起来一目了然。
class Accuracy(EvalMetric):
# 这里也是继承的mxnet.metric.EvalMetric这个类,只不过因为这两个类在同一个脚本中定义,因此就可以直接写EvalMetric
"""Computes accuracy classification score.
The accuracy score is defined as
.. math::
\\text{accuracy}(y, \\hat{y}) = \\frac{1}{n} \\sum_{i=0}^{n-1}
\\text{1}(\\hat{y_i} == y_i)
Parameters
----------
axis : int, default=1
The axis that represents classes
name : str
Name of this metric instance for display.
output_names : list of str, or None
Name of predictions that should be used when updating with update_dict.
By default include all predictions.
label_names : list of str, or None
Name of labels that should be used when updating with update_dict.
By default include all labels.
Examples
--------
>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels = [mx.nd.array([0, 1, 1])]
>>> acc = mx.metric.Accuracy()
>>> acc.update(preds = predicts, labels = labels)
>>> print acc.get()
('accuracy', 0.6666666666666666)
"""
def __init__(self, axis=1, name='accuracy',output_names=None, label_names=None):
# super这个函数是调用基类mx.metric.EvalMetric的__init__函数,
# __init__函数括号中的变量是要传递给基类的__init__函数的变量。
# super()括号中的Accuracy表示类名称。
super(Accuracy, self).__init__(
name, axis=axis,
output_names=output_names, label_names=label_names)
self.axis = axis
def update(self, labels, preds):
"""Updates the internal evaluation result.
Parameters
----------
labels : list of `NDArray`
The labels of the data.
preds : list of `NDArray`
Predicted values.
"""
check_label_shapes(labels, preds)
# zip函数可以将输入的两个list的对应位置的值变成一个元组(tuple),这样每个tuple就包含两个值,
# 这两个值在这里都是NDArray格式。又因为pred_label的shape和label的shape是不一样的,
# 所以都会进入下面这个if语句,也就是先将pred_label按行求出最大值所在的index,
# 然后pred_label就和label是相同shape的NDArray了。
for label, pred_label in zip(labels, preds):
if pred_label.shape != label.shape:
pred_label = ndarray.argmax(pred_label, axis=self.axis)
# 先用asnumpy()方法将NDArray转换成numpy.ndarray,然后把数值转成32位int型,原来是浮点型
pred_label = pred_label.asnumpy().astype('int32')
label = label.asnumpy().astype('int32')
check_label_shapes(label, pred_label)
# (pred_label.flat == label.flat) 会返回numpy.ndarray格式,内容是false或者true,
# 表示相等或不相等,最后求一个和
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
参考1:https://blog.csdn.net/u014380165/article/details/78311231
参考2:https://blog.csdn.net/fuwenyan/article/details/79902002