mxnet常用的数据评估指标

mxnet最近更新很多文档,其中就包括了常用的数据评估指标。相关文档参考:http://mxnet.io/api/python/metric.html#overview

仔细看了一下,大部分都是分类用的评估指标,线性回归的很少。我猜可能是因为mxnet做线性回归不太行的缘故。下面把常用的指标列出来,以备查看:

#encoding=utf-8
'''
测试用的校验指标. 包括准确率,F1指标、topk和Perplexity
http://mxnet.io/api/python/metric.html#overview
'''
import mxnet as mx
import numpy as np
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels   = [mx.nd.array([0, 1, 1])]

#1、最常用的准确率Accuracy
eval_metrics_1 = mx.metric.Accuracy()
#2、分类的综合评估指标F1. This F1 score only supports binary classification
eval_metrics_2 = mx.metric.F1()
eval_metrics = mx.metric.CompositeEvalMetric()
for child_metric in [eval_metrics_1, eval_metrics_2]:
    eval_metrics.add(child_metric)

eval_metrics.update(labels = labels, preds = predicts)
print eval_metrics.get() #(['accuracy', 'f1'], [0.6666666666666666, 0.8])

#3、平均绝对误差MAE
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
mean_absolute_error = mx.metric.MAE()
mean_absolute_error.update(labels = labels, preds = predicts)
print mean_absolute_error.get() #('mae', 0.5)

#4、均方差MSE
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
mean_squared_error = mx.metric.MSE()
mean_squared_error.update(labels = labels, preds = predicts)
print mean_squared_error.get()#('mse', 0.375)

#5、标准差RMSE
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
root_mean_squared_error = mx.metric.RMSE()
root_mean_squared_error.update(labels = labels, preds = predicts)
print root_mean_squared_error.get()#('rmse', 0.61237245798110962)

#6、交叉熵CrossEntropy
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels   = [mx.nd.array([0, 1, 1])]
ce = mx.metric.CrossEntropy()
ce.update(labels, predicts)#('cross-entropy', 0.57159948348999023)
print ce.get()

#7、前K项指标,top k指标。k越大,值越大,因为包含的可能性更高。
np.random.seed(999)
top_k = 3 #前3项指标
labels = [mx.nd.array([2, 6, 9, 2, 3, 4, 7, 8, 9, 6])]
predicts = [mx.nd.array(np.random.rand(10, 10))]
acc = mx.metric.TopKAccuracy(top_k=top_k)
acc.update(labels, predicts)
print acc.get() # ('top_k_accuracy_3', 0.3)

top_k = 5 #前5项指标
acc = mx.metric.TopKAccuracy(top_k=top_k)
acc.update(labels, predicts)
print acc.get() # ('top_k_accuracy_5', 0.6)

'''
8、Perplexity指标。
Perplexity is a measurement of how well a probability distribution or model predicts a sample.
A low perplexity indicates the model is good at predicting the sample.
简单来说,perplexity就是对于语言模型所估计的一句话出现的概率.Perplexity其实表示的是average branch factor,
大概可以翻译为平均分支系数。即平均来说,我们预测下一个词时有多少种选择。
摘录自:http://blog.csdn.net/luo123n/article/details/48902815
'''
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels   = [mx.nd.array([0, 1, 1])]
perp = mx.metric.Perplexity(ignore_label=None)
perp.update(labels, predicts)
print perp.get() #('Perplexity', 1.7710976285155853)

转载于:https://my.oschina.net/qinhui99/blog/994789

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值