MXNet框架如何自定义evaluation metric

本文介绍了在MXNet框架下如何自定义评价标准(Evaluation Metric),详细讲解了自定义metric的步骤,包括初始化、更新函数的实现,并通过代码示例展示了如何在训练过程中调用自定义的metric。此外,还分享了MXNet官方提供的metric类的使用方法和一些默认metric类的写法。
摘要由CSDN通过智能技术生成

使用深度学习框架训练模型的时候都需要用到评价标准,比如准确率等,那么在MXNet框架下,这些评价标准(Evaluation Metric)是怎么实现的呢?如果我们要自定义一个不一样的评价标准要怎么做?一起来了解下吧。

首先来看看在MXNet框架下关于evaluation metric的最基本的类和脚本。mxnet.metric.EvalMetric是MXNet框架中计算评价标准(evaluation metric)的基础类,这个基础类是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py中定义的。metric.py脚本中不仅包含类基础类的定义,还包含MXNet所有和metric相关的类的定义,所以如果你想更深入了解的话,可以看看这个脚本。

先来介绍下如何自定义一个evaluation metric。作者在文档中说过,这个类没有办法直接使用,当你要用的时候,应该要先自定义一个继承该基础类的类,将你想实现的评价标准写在里面,然后再在你的训练代码中调用该类。先看一个简单的定义一个evaluation metric的例子:

import mxnet as mx
class Accuracy(mx.metric.EvalMetric): # 在定义类名称的时候,括号里面表示继承哪个类
    def __init__(self, num=None):
        super(Accuracy, self).__init__('accuracy', num)

    def update(self, labels, preds):
        pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
        label = labels[0].asnumpy().astype('int32')

        mx.metric.check_label_shapes(label, pred_label)

        self.sum_metric += (pred_label.flat == label.flat).sum()
        self.num_inst += len(pred_label.flat)

在这个自定义类中主要包含一个__init__ 函数和一个update函数,前者是用来初始化,后者是用来更新metric,比如说你定义每计算一个batch的样本,就更新一次metric。updata函数的输入中,labels和preds都是一个NDArray的列表。在update函数中有个mx.metric.check_label_shaoes函数,这个函数的定义也是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py脚本中,是用来判断labels和preds的shape是否一致,因为labels和preds都是list,而且一般这个list中只包含一个NDArray,比如说你的batch size是16,类别数是1000,那么labels中的NDArray就是16*1,preds中的NDArray就是16*1000。

以上这些代码可以写在类似名字为my_metric.py脚本中,那么怎么调用了?在你的训练代码中,可以通过mxnet.metric.CompositeEvalMetric类来调用,这个类是用来管理你的evaluation metrics的,这个类的定义也是在MXNet项目的/incubator-mxnet/python/mxnet/metric.py脚本中。先来看看官方的关于使用这个类的例子&#

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值