使用深度学习框架训练模型的时候都需要用到评价标准,比如准确率等,那么在MXNet框架下,这些评价标准(Evaluation Metric)是怎么实现的呢?如果我们要自定义一个不一样的评价标准要怎么做?一起来了解下吧。
首先来看看在MXNet框架下关于evaluation metric的最基本的类和脚本。mxnet.metric.EvalMetric
是MXNet框架中计算评价标准(evaluation metric)的基础类,这个基础类是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py
中定义的。metric.py
脚本中不仅包含类基础类的定义,还包含MXNet所有和metric相关的类的定义,所以如果你想更深入了解的话,可以看看这个脚本。
先来介绍下如何自定义一个evaluation metric。作者在文档中说过,这个类没有办法直接使用,当你要用的时候,应该要先自定义一个继承该基础类的类,将你想实现的评价标准写在里面,然后再在你的训练代码中调用该类。先看一个简单的定义一个evaluation metric的例子:
import mxnet as mx
class Accuracy(mx.metric.EvalMetric): # 在定义类名称的时候,括号里面表示继承哪个类
def __init__(self, num=None):
super(Accuracy, self).__init__('accuracy', num)
def update(self, labels, preds):
pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
label = labels[0].asnumpy().astype('int32')
mx.metric.check_label_shapes(label, pred_label)
self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
在这个自定义类中主要包含一个__init__
函数和一个update
函数,前者是用来初始化,后者是用来更新metric,比如说你定义每计算一个batch的样本,就更新一次metric。updata函数的输入中,labels和preds都是一个NDArray的列表。在update函数中有个mx.metric.check_label_shaoes
函数,这个函数的定义也是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py
脚本中,是用来判断labels和preds的shape是否一致,因为labels和preds都是list,而且一般这个list中只包含一个NDArray,比如说你的batch size是16,类别数是1000,那么labels中的NDArray就是16*1,preds中的NDArray就是16*1000。
以上这些代码可以写在类似名字为my_metric.py
脚本中,那么怎么调用了?在你的训练代码中,可以通过mxnet.metric.CompositeEvalMetric
类来调用,这个类是用来管理你的evaluation metrics的,这个类的定义也是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py
脚本中。先来看看官方的关于使用这个类的例子&#