MXNet recently refreshed a lot of its documentation, including the pages on commonly used evaluation metrics. Reference: http://mxnet.io/api/python/metric.html#overview
Looking through the list carefully, most of the metrics are for classification; very few target regression. My guess is that this is because MXNet is rarely used for linear regression. The commonly used metrics are listed below for reference:
```python
# encoding=utf-8
'''
Evaluation metrics for testing.
Covers accuracy, F1, top-k accuracy, and Perplexity.
http://mxnet.io/api/python/metric.html#overview
'''
import mxnet as mx
import numpy as np

predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels = [mx.nd.array([0, 1, 1])]

# 1. Accuracy, the most commonly used metric
eval_metrics_1 = mx.metric.Accuracy()

# 2. F1, a combined precision/recall metric for classification.
#    This F1 score only supports binary classification.
eval_metrics_2 = mx.metric.F1()

eval_metrics = mx.metric.CompositeEvalMetric()
for child_metric in [eval_metrics_1, eval_metrics_2]:
    eval_metrics.add(child_metric)
eval_metrics.update(labels=labels, preds=predicts)
print(eval_metrics.get())  # (['accuracy', 'f1'], [0.6666666666666666, 0.8])

# 3. Mean absolute error (MAE)
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4, 1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4, 1))]
mean_absolute_error = mx.metric.MAE()
mean_absolute_error.update(labels=labels, preds=predicts)
print(mean_absolute_error.get())  # ('mae', 0.5)

# 4. Mean squared error (MSE)
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4, 1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4, 1))]
mean_squared_error = mx.metric.MSE()
mean_squared_error.update(labels=labels, preds=predicts)
print(mean_squared_error.get())  # ('mse', 0.375)

# 5. Root mean squared error (RMSE)
predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4, 1))]
labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4, 1))]
root_mean_squared_error = mx.metric.RMSE()
root_mean_squared_error.update(labels=labels, preds=predicts)
print(root_mean_squared_error.get())  # ('rmse', 0.61237245798110962)

# 6. Cross entropy
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels = [mx.nd.array([0, 1, 1])]
ce = mx.metric.CrossEntropy()
ce.update(labels, predicts)
print(ce.get())  # ('cross-entropy', 0.57159948348999023)

# 7. Top-k accuracy. A larger k gives a larger value,
#    since more candidate predictions are counted as hits.
np.random.seed(999)
top_k = 3  # top-3 accuracy
labels = [mx.nd.array([2, 6, 9, 2, 3, 4, 7, 8, 9, 6])]
predicts = [mx.nd.array(np.random.rand(10, 10))]
acc = mx.metric.TopKAccuracy(top_k=top_k)
acc.update(labels, predicts)
print(acc.get())  # ('top_k_accuracy_3', 0.3)

top_k = 5  # top-5 accuracy
acc = mx.metric.TopKAccuracy(top_k=top_k)
acc.update(labels, predicts)
print(acc.get())  # ('top_k_accuracy_5', 0.6)

'''
8. Perplexity.
Perplexity is a measurement of how well a probability distribution or model
predicts a sample. A low perplexity indicates the model is good at predicting
the sample. Roughly speaking, perplexity is derived from the probability a
language model assigns to a sentence, and can be read as an average branching
factor: on average, how many choices the model faces when predicting the
next word.
Adapted from: http://blog.csdn.net/luo123n/article/details/48902815
'''
predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
labels = [mx.nd.array([0, 1, 1])]
perp = mx.metric.Perplexity(ignore_label=None)
perp.update(labels, predicts)
print(perp.get())  # ('Perplexity', 1.7710976285155853)
```
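As a sanity check, the regression and entropy metrics above can be reproduced with plain NumPy, with no MXNet dependency. This is a minimal sketch using the same prediction/label vectors; tiny discrepancies against MXNet's printed values (e.g. for RMSE and cross-entropy) come from MXNet computing in float32:

```python
import numpy as np

# Regression metrics from the example above, recomputed by hand.
preds = np.array([3.0, -0.5, 2.0, 7.0])
labels = np.array([2.5, 0.0, 2.0, 8.0])

mae = np.abs(labels - preds).mean()    # mean absolute error
mse = ((labels - preds) ** 2).mean()   # mean squared error
rmse = np.sqrt(mse)                    # root mean squared error

print(mae)   # 0.5    -- matches mx.metric.MAE
print(mse)   # 0.375  -- matches mx.metric.MSE
print(rmse)  # ~0.6124 -- matches mx.metric.RMSE up to float32 precision

# Cross-entropy and perplexity on the classification example.
probs = np.array([[0.3, 0.7], [0.0, 1.0], [0.4, 0.6]])
y = np.array([0, 1, 1])

# Average negative log-probability assigned to the true class,
# and its exponential (perplexity = exp(cross-entropy)).
ce = -np.log(probs[np.arange(len(y)), y]).mean()
perplexity = np.exp(ce)

print(ce)          # ~0.5716 -- matches mx.metric.CrossEntropy
print(perplexity)  # ~1.7711 -- matches mx.metric.Perplexity
```

Note that perplexity is just the exponential of the cross-entropy, which is why the two MXNet metrics agree on this data: exp(0.5716) ≈ 1.7711.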