DL4J中文文档/调优与训练/评估

为什么要评估?

当训练或部署神经网络时,了解模型的准确性是有用的。在DL4J中,评估类和评估类的变体可用于评估模型的性能。

分类评价

评估类用于评估二分类和多类分类器(包括时间序列分类器)的性能。本节介绍了评估类的基本用法。

给定一个DataSetIterator形式的数据集,执行评估的最简单方法是使用MultiLayerNetwork和ComutationGraph上的内置评估方法:

DataSetIterator myTestData = ...
Evaluation eval = model.evaluate(myTestData);

然而,也可以对单个小批量进行评价。这里是一个例子,从我们的示例项目中数据实例/CSV示例中获得。

CSV的例子有3类花的CSV数据,建立了一个简单的前馈神经网络用于对基于4个测量值的花的分类。

Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());
eval.eval(testData.getLabels(), output);
log.info(eval.stats());

第一行创建一个具有3个类的评估对象。第二行从模型中获取我们测试数据集的标签。第三行使用eval方法将来自testdata的标签数组与从模型生成的标签进行比较。第四行将评估数据记录到控制台。

输出

Examples labeled as 0 classified by model as 0: 24 times
Examples labeled as 1 classified by model as 1: 11 times
Examples labeled as 1 classified by model as 2: 1 times
Examples labeled as 2 classified by model as 2: 17 times


==========================Scores========================================
 # of classes:    3
 Accuracy:        0.9811
 Precision:       0.9815
 Recall:          0.9722
 F1 Score:        0.9760
Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)
========================================================================

默认情况下,.stats() 方法显示混淆矩阵条目(每行一个)、准确度、精度、召回率和F1分数。此外,评估类还可以计算并返回以下值:

  • 混淆矩阵
  • 假阳性/阴性率
  • 真阳性/阴性
  • 类别计数
  • F-beta, G-measure, Matthews 关系数及更多, 查看 Evaluation JavaDoc

显示混淆矩阵。

System.out.println(eval.confusionToString());

显示

Predicted:         0      1      2
Actual:
0  0          |      16      0      0
1  1          |       0     19      0
2  2          |       0      0     18

此外,可以直接访问混淆矩阵,使用CSV或HTML转换。

eval.getConfusionMatrix() ;
eval.getConfusionMatrix().toHTML();
eval.getConfusionMatrix().toCSV();

回归评估

为了评估执行回归的网络,使用回归评估类。

带着评估类,一个DataSetIterator上的回归评估可以执行如下:

DataSetIterator myTestData = ...
RegressionEvaluation eval = model.evaluateRegression(myTestData);

这里有一个单列的代码片段,在这种情况下,神经网络是根据测量值来预测自己的年龄。

RegressionEvaluation eval =  new RegressionEvaluation(1);

打印评估的统计数据。

System.out.println(eval.stats());

返回

Column    MSE            MAE            RMSE           RSE            R^2            
col_0     7.98925e+00    2.00648e+00    2.82653e+00    5.01481e-01    7.25783e-01    

列是均方误差、均方绝对误差、均方根误差、相对平方误差和R^2决定系数。

查看 回归评估JavaDoc

同时进行多个评价

当执行多种类型的评估时(例如,在同一网络和数据集上执行评估和ROC),在数据集的一次传递中执行以下操作更有效:

DataSetIterator testData = ...
Evaluation eval = new Evaluation();
ROC roc = new ROC();
model.doEvaluation(testdata, eval, roc);

时间序列评估

时间序列评估与上述评估方法非常相似。DL4J中的评估对所有(非掩码的)时间步分别执行——例如,长度为10的时间序列将为评估对象贡献10个预测/标签。与时间序列的一个不同之处在于掩码数组是(可选的),这些掩码数组用于将一些时间步标记为丢失或不存在。请参阅使用RNNS掩码以获得更多关于掩码的细节。

对于大多数用户来说,仅仅使用 MultiLayerNetwork.evaluate(DataSetIterator) 或  MultiLayerNetwork.evaluateRegression(DataSetIterator) 和类似的方法就足够了。如果掩码数组存在,这些方法将正确地处理掩码。

二分类器评估

EvaluationBinary用于评估具有二分类输出的网络——这些网络通常具有Sigmoid激活函数和XENT损失函数。为每个输出计算典型的分类度量,例如准确度、精度、召回率、F1得分等。

EvaluationBinary eval = new EvaluationBinary(int size)

查看 EvaluationBinary JavaDoc

ROC

ROC(接收者操作特征)是另一种常用的评估分类器的评估指标。DL4J中存在三个ROC变体:

  • ROC -用“一对全部”的方法评估非二分类器
  • ROCBinary - 用于单二分类标签(作为单列概率,或两列的softmax概率分布)
  • ROCMultiClass - 用于多二分类标签

这些类具有通过calculateAUC()和calculateAUPRC()方法计算ROC曲线下面积(AUROC)和精确度-召回曲线下面积(AUPRC)的能力。此外,可以使用getRocCurve()getPrecisionRecallCurve()获得ROC和精确度-召回曲线。

ROC和精确度-召回曲线可以导出到HTML以便查看,使用:“EvaluationTools.exportRocChartsToHtmlFile(ROC,File)”,该文件将导出具有ROC和精确度-召回曲线的HTML文件,可以在浏览器中查看。

注意,所有三种支持两种操作/计算模式。

  • 阈值(近似AUROC/AUPRC计算,无内存问题)
  • 精确(精确的AUROC/AUPRC计算,但是对于非常大的数据集(即具有数百万个示例的数据集)可能需要大量的内存

可以使用构造函数设置容器的数量。可以使用默认构造函数new ROC()来精确设置,或者显式地使用new ROC(0)

参见ROCBinary JavaDoc用于评估二元分类器。

评估分类器校准

DL4J还具有评估校准类,它被设计用于分析分类器的校准。它提供了许多的工具用于如下目的:

  • 每个类别的标签数量和预测的计数
  • 可靠性图(或可靠性曲线)
  • 残差图(直方图)
  • 概率直方图,包括每个类的概率

使用评估校准的分类器评估方式与其它评估类相似。可以使用EvaluationTools.exportevaluationCalibrationToHtmlFile(EvaluationCalibration, File)将各种绘图/直方图导出到HTML以便查看。

Spark网络的分布式评估

SparkDl4jMultiLayer 和 SparkComputationGraph 都有相似的评估方法:

Evaluation eval = SparkDl4jMultiLayer.evaluate(JavaRDD<DataSet>);

//一次传递多次评价:
SparkDl4jMultiLayer.doEvaluation(JavaRDD<DataSet>, IEvaluation...);

多任务网络评估

多任务网络是经过训练以产生多个输出的网络。例如,可以对给定音频样本的网络进行训练,以预测说话者的语言和说话人的性别。这里简要描述了多任务配置。

适用于多任务网络的评估类

查看 ROCMultiClass JavaDoc

查看 ROCBinary JavaDoc

可用的评估


Evaluation

[源码]

评估指标:

  • 精度,召回率,F1,FBeta,准确度,马休斯相关系数,gMeasure
    argmax / 0.5)

     

     注意:在使用用于二分类度量(如F1、精确度、召回等)的评估类时应小心。有许多案例需要考虑:
    1. 对于二分类(1或2个网络输出)

      c)在两个类上使用宏平均度量进行二分类(不常见且通常不可取),如上(b)所示,指定“null”作为参数(而不是0或1)

      将报告宏平均(一个对全部)二分类度量。请注意,可以指定微vs宏平均

      注意,设置自定义二进制决策阈值仅对于二分类情况(1或2个输出)是可能的,并且如果类的数量超过2,则不能使用它。概率>阈值的预测被认为是类1,否则被认为是类0。


      成本数组(行向量,大小等于输出数量)修改评估过程:我们不是简单地执行predictedClass = argMax(probabilities),而是执行predictedClass = argMax(cost probabilities)。因此,所有1s的数组(或者实际上任何相等值的数组)将导致与无成本数组相同的性能;非相等值将偏离对某些类的预测。

Evaluation

public Evaluation(int numClasses) 

评估中要考虑的分类数

  • 参数 numClasses 评估中要考虑的分类数

Evaluation

public Evaluation(int numClasses, Integer binaryPositiveClass)

构造函数,用于指定类的数目,并且可选地用于二分类的正类。有关二分类情况下的评估的详细信息,请参见评估JavaDoc

  • 参数 numClasses 评估的分类数。必须是2,如果binaryPositiveClass是非空的
  • 参数 binaryPositiveClass 如果非空,则为正类(0或1)。

eval

public void eval(INDArray trueLabels, INDArray input, ComputationGraph network) 

对 使用给定的true标签的输出、计算图网络输入和用于评估的计算图网络 进行评估

  • 参数 trueLabels 使用的标签
  • 参数 input 用于评估的网络输入
  • 参数 network 用于输出的网络

eval

public void eval(INDArray trueLabels, INDArray input, MultiLayerNetwork network) 

对 使用给定的true标签的输出、多层网络输入和用于评估的多层网络 进行评估

  • 参数 trueLabels 使用的标签
  • 参数 input 用于评估的网络输入
  • 参数 network 用于输出的网络

eval

public void eval(INDArray realOutcomes, INDArray guesses) 

收集关于真实结果和猜测的统计数据。这是逻辑的结果矩阵。

请注意,如果传递的两个矩阵中长度不相同,则会抛出IllegalArgumentException。

  • 参数 realOutcomes 真实的结果(标签-通常是二分类的)
  • 参数 guesses 猜测/预测 (通常是概率向量)

eval

public void eval(final INDArray realOutcomes, final INDArray guesses,
                    final List<? extends Serializable> recordMetaData) 

用可选元数据评估网络

  • 参数 realOutcomes 数据标签
  • 参数 guesses 网络预测
  • 参数 recordMetaData 可选的;可以是空的。如果不是NULL,则其大小应该等于结果/猜测的数量。

eval

public void eval(int predictedIdx, int actualIdx) 

评估单一预测(一次一个预测)

  • 参数 predictedIdx 网络预测类索引
  • 参数 actualIdx 实际类索引

stats

public String stats() 

以字符串形式报告统计信息

  • 返回分类统计信息

stats

public String stats(boolean suppressWarnings) 

以字符串形式获取分类报告的方法。

  • 参数 suppressWarnings 是否输出与评估结果相关的警告
  • 返回(多行)字符串的准确性、精确性、召回、F1得分等

stats

public String stats(boolean suppressWarnings, boolean includeConfusion)

以字符串形式获取分类报告的方法。

  • 参数 suppressWarnings 是否输出与评估结果相关的警告
  • 参数 includeConfusion 混淆矩阵是否应包含在返回的统计数据中
  • 返回(多行)字符串的准确性、精确性、召回、F1得分等

confusionMatrix

public String confusionMatrix()

将混淆矩阵作为字符串获取

  • 作为字符串返回混淆矩阵

precision

public double precision(Integer classLabel) 

返回给定类标签的精度

  • 参数 classLabel 标签
  • 返回标签的精度

precision

public double precision(Integer classLabel, double edgeCase) 

返回给定类标签的精度

  • 参数 classLabel 标签
  • 参数 edgeCase 在0/0情况时的输出
  • 返回标签的精度

precision

public double precision() 

迄今为止,基于猜测的精确性。

注意:返回的值将根据类的数量和设置而不同。

  1. 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
  2. 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均精度,相当于precision(EvaluationAveraging.Macro)。
  • 基于猜测返回总精度

precision

public double precision(EvaluationAveraging averaging) 

计算所有类的平均精度。可以指定是应该使用宏平均还是微观平均。注意:如果任何类具有tp=0和fp=0,(精度=0/0),则这些类被排除在平均值之外。

  • 参数 averaging 平均法-宏或微
  • 返回平均精度

averagePrecisionNumClassesExcluded

public int averagePrecisionNumClassesExcluded() 

在计算(宏)平均精度时,由于没有预测平均中排除了多少类——即,精度是0/0的边缘情况。

  • 返回从平均精度排除的类数

averageRecallNumClassesExcluded

public int averageRecallNumClassesExcluded() 

在计算(宏)平均召回时,由于没有预测平均中排除了多少类——即,召回是0/0的边缘情况。

  • 返回从平均召回排除的类数

averageF1NumClassesExcluded

public int averageF1NumClassesExcluded() 

在计算(宏)平均F1时,由于没有预测,从平均值中排除了多少类——即,F1将根据0/0的精度或召回率来计算。

  • 返回从平均F1排除的类数

averageFBetaNumClassesExcluded

public int averageFBetaNumClassesExcluded() 

在计算(宏)平均FBeta时,由于没有预测,从平均值中排除了多少类——即,FBeta将根据0/0的精度或召回率来计算。

  • 返回从平均FBeta排除的类数

recall

public double recall(int classLabel) 

返回给定标签的召回率

  • 参数 classLabel 标签
  • 返回double类型的召回率

recall

public double recall(int classLabel, double edgeCase) 

返回给定标签的召回率

  • 参数 classLabel 标签
  • 参数 edgeCase 在0/0的情况下的输出
  • 返回double类型的召回率

recall

public double recall() 

迄今为止基于猜测的召回

注意:返回的值将根据类的数量和设置而不同。

  1. 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
  2. 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均召回,相当于recall(EvaluationAveraging.Macro)。
  • 为结果返回召回

recall

public double recall(EvaluationAveraging averaging) 

计算所有类的平均召回-可以指定是使用宏平均还是微观平均。注意:如果任何类都具有TP=0和fn=0,(召回=0/0),这些都是从平均值中排除的。

  • 参数 averaging 平均方法-宏或微
  • 返回平均召回率

falsePositiveRate

public double falsePositiveRate(int classLabel) 

返回给定标签的假阳性率

  • 参数 classLabel 标签
  • 返回double类型的假阳性率

falsePositiveRate

public double falsePositiveRate(int classLabel, double edgeCase) 

返回给定标签的假阳性率

  • 参数 classLabel 标签
  • 参数 edgeCase  0/0时的输出
  • 返回double类型的假阳性率

falsePositiveRate

public double falsePositiveRate() 

迄今为止基于猜测的假阳性率  注意:返回的值将根据类的数量和设置而不同。

  1. 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
  2. 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均假阳性率,相当于falsePositiveRate(EvaluationAveraging.Macro)。
  • 返回输出假阳性率

falsePositiveRate

public double falsePositiveRate(EvaluationAveraging averaging) 

计算所有类别的平均假阳性率。可以指定是应该使用宏平均还是微观平均

  • 参数 averaging 平均方法.宏观或微观
  • 返回平均假阳性率

falseNegativeRate

public double falseNegativeRate(Integer classLabel) 

返回给定标签的假阴性率

  • 参数 classLabel 标签
  • 返回double类型的假阴性率

falseNegativeRate

public double falseNegativeRate(Integer classLabel, double edgeCase) 

返回给定标签的假阴性率

  • 参数 classLabel 标签
  • 参数 edgeCase 在0/0的情况下的输出
  • 返回double类型的假阴性率

falseNegativeRate

public double falseNegativeRate() 

迄今为止基于猜测的假阴性率  注意:返回的值将根据类的数量和设置而不同。

  1. 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
  2. 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均假阴性率,相当于falseNegativeRate(EvaluationAveraging.Macro)。
  • 返回输出假阳性率

 

falseNegativeRate

public double falseNegativeRate(EvaluationAveraging averaging) 

计算所有类别的平均假阴性率。可以指定是应该使用宏平均还是微观平均

  • 参数 averaging 平均方法.宏观或微观
  • 返回平均假阴性率

falseAlarmRate

public double falseAlarmRate() 

误报率反映了对分类记录的错误分类率。http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw  注意:返回的值将根据类的数量和设置而不同。

  1. 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
  2. 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均误报率。
  • 返回输出误报率

f1

public double f1(int classLabel) 

计算给定分类的F1分数

  • 参数 classLabel 计算F1的标签
  • 返回给定标签的F1分数

fBeta

public double fBeta(double beta, int classLabel) 

计算给定类的FBeta,其中FBeta定义为:

(1 +beta^ 2)(精确召回)/(beta^ 2精度+召回)。

F1是FBeta的一个特例,具有beta=1。

  • 参数 beta 使用的Beta值
  • 参数 classLabel 分类标签
  • 返回 FBeta

fBeta

public double fBeta(double beta, int classLabel, double defaultValue) 

计算给定类的FBeta,其中FBeta定义为:

(1 +beta^ 2)(精确召回)/(beta^ 2精度+召回)。

F1是FBeta的一个特例,具有beta=1。

  • 参数 beta 使用的Beta值
  • 参数 classLabel 分类标签
  • 参数 defaultValue 精度或召回未定义(精度或召回为0/0)时的缺省值
  • 返回 FBeta

f1

public double f1() 

计算F1得分

F1得分定义为:

TP:真阳性

FP:假阳性

FN:假阴性

F1得分:2 TP/(2TP+FP+FN)

注意:返回的值将根据类的数量和设置而不同。

  1. 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
  2. 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均 f1,相当于 f1(EvaluationAveraging.Macro)。
  • 返回基于当前猜测的f1分数或精度与召回的调和平均 

f1

public double f1(EvaluationAveraging averaging) 

计算所有类别的F1得分。可以指定是应该使用宏平均还是微观平均

  • 参数 averaging 平均方法.宏观或微观

fBeta

public double fBeta(double beta, EvaluationAveraging averaging) 

计算所有类别的F_beta得分。可以指定是应该使用宏平均还是微观平均

  • 参数 beta 使用的Beta值
  • 参数 averaging 平均方法.宏观或微观

 

gMeasure

public double gMeasure(int output) 

计算给定输出的G-measure

  • 参数 output 指定输出
  • 返回指定输出的G-measure

gMeasure

public double gMeasure(EvaluationAveraging averaging) 

使用微或宏平均计算所有输出的平均Gmeasure

  • 参数 averaging 平均方法.宏观或微观
  • 返回平均G measure

accuracy

public double accuracy() 

准确率: (TP + TN) / (P + N)

  • 返回到目前为止猜测的准确率

topNAccuracy

public double topNAccuracy() 

迄今为止预测的第N高的准确率。对于top n=1(默认值),相当于accuracy()

  • 返回 前N 准确率

matthewsCorrelation

public double matthewsCorrelation(int classIdx) 

为指定的类计算二进制马修斯相关系数。
MCC = (TPTN - FPFN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))

  • 参数 classIdx 计算马休斯相关系数的类指标

matthewsCorrelation

public double matthewsCorrelation(EvaluationAveraging averaging) 

计算平均二进制马修斯相关系数,使用宏观或微观平均。
MCC = (TPTN - FPFN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
注:这与多类马休斯相关系数不相同。

  • 参数 averaging 平均方法
  • 返回平均系数

incrementTruePositives

public void incrementTruePositives(Integer classLabel) 

真阳性:正确拒绝

  • 到目前为止返回全部真阳性

addToConfusion

public void addToConfusion(Integer real, Integer guess) 

添加到混淆矩阵

  • 参数 real 实际猜测
  • 参数 guess 系统猜测

classCount

public int classCount(Integer clazz) 

返回给定标签实际发生的次数。

  • 参数 clazz 标签
  • 返回标签实际发生的次数

getTopNCorrectCount

public int getTopNCorrectCount() 

根据前N个值返回正确预测的数目。对于前n=1(默认值),这相当于正确预测的数目。

  • 返回正确预测的前N数字

getTopNTotalCount

public int getTopNTotalCount() 

返回前N个评估的总数。大多数情况下,这完全等于getNumRowCounter(),但是在使用eval(int, int)的情况下可能不同,因为在这种情况下无法计算前N个精度(即,需要完整的概率分布,而不仅仅是预测/实际索引)。

  • 返回前N个预测的总数

 

getConfusionMatrix()

public ConfusionMatrix<Integer> getConfusionMatrix()

返回混淆矩阵变量

  • 返回此评估的混淆矩阵变量

confusionToString

public String confusionToString() 

获取混淆矩阵的字符串表示形式。

getPredictionErrors()

 public List<Prediction> getPredictionErrors() 

根据每个记录获取预测误差列表

注意:预测错误只能在“元数据评估”方法:eval(INDArray, INDArray, List)使用之后才能使用。否则(如果元数据没有通过前面提到的eval方法记录),则将每个预测输出分割成单独的Prediction对象是没有价值的,使用混淆矩阵获得计数,通过getConfusionMatrix()方法来替代。

  • 返回预测错误列表,或者如果没有记录元数据,则为null。

getPredictionsByActualClass(int actualClass) 

public List<Prediction> getPredictionsByActualClass(int actualClass) 

获取具有指定实际类的所有数据的预测列表,而不管预测类是什么。

注意:预测错误只能在“元数据评估”方法:eval(INDArray, INDArray, List)使用之后才能使用。否则(如果元数据没有通过前面提到的eval方法记录),则将每个预测输出分割成单独的Prediction对象是没有价值的,使用混淆矩阵获得计数,通过getConfusionMatrix()方法来替代。

  • 返回预测错误列表,或者如果没有记录元数据,则为null。
  • 参数 actualClass 用于预测的实际类

ROCBinary

[源码]

ROC (Receiver Operating Characteristic) 用于多任务二分类分类器

根据{@link ROC},ROCBinary支持exact(thersholdSteps==0)和阈值化;有关详细信息,请参见{@link ROC}。

与{@link ROC}不同(它支持单个二分类标签(作为单列概率或2列“softmax”概率)分布),ROCBinary假设所有输出都是独立的二分类变量。这也不同于{@link ROCMultiClass},它应该用于多类(单个非二分类)情况。ROCBinary支持每个示例和每个输出掩码:对于每个输出掩码,可能缺少任何特定的输出(掩码值0),因此不包括在计算的ROC中。

ROCBinary

public ROCBinary(int thresholdSteps) 
  • 参数 thresholdSteps 用于ROC计算的阈值步骤数。设置为0用于精确的ROC计算

ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts)

public ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts)
  • 参数 thresholdSteps 用于ROC计算的阈值步骤数。设置为0用于精确的ROC计算
  • 参数 rocRemoveRedundantPts 通常设置为true。如果为true,则从ROC和P R曲线中删除任何冗余点。

numLabels

public int numLabels() 

如果已知,返回标签的数目(即,预测/标签数组的大小)。否则返回1

getCountActualPositive

public long getCountActualPositive(int outputNum) 

获取指定输出/列的实际阳性统计(考虑任何掩码)

  • 参数 outputNum 输出索引 (0 到 {- link #numLabels()}-1)

getCountActualNegative

public long getCountActualNegative(int outputNum) 

获取指定输出/列的实际阴性统计(考虑任何掩码)

  • 参数 outputNum 输出索引 (0 到 {- link #numLabels()}-1)

getRocCurve

public RocCurve getRocCurve(int outputNum) 

获取指定输出的ROC曲线

  • 参数 outputNum 获取ROC曲线的输出数
  • 返回 ROC 曲线

getPrecisionRecallCurve

public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum) 

获取指定输出的精度召回曲线

  • 参数 outputNum 获得R—R曲线的输出数
  • 返回精确召回曲线

calculateAverageAuc

public double calculateAverageAuc() 

所有结果的宏观平均AUC

  • 返回所有结果的(宏观)平均AUC。

calculateAverageAUCPR

public double calculateAverageAUCPR()
  •  返回(宏)平均AUPRC(精确召回曲线下的面积)

calculateAUC

public double calculateAUC(int outputNum) 

计算(ROC)曲线下的AUC面积

在内部利用梯形积分

  • 参数 outputNum 输出数计算AUC
  • 返回 AUC

calculateAUCPR

public double calculateAUCPR(int outputNum) 

精度召回曲线下的AUCPR面积计算

在内部利用梯形积分

  • 参数 outputNum 计算AUCPR的输出数据
  • 返回 AUCPR

setLabelNames

public void setLabelNames(List<String> labels) 

设置标签名称,用于通过stats()打印


ConfusionMatrix

[源码]

创建一个空混淆矩阵

ConfusionMatrix

public ConfusionMatrix(ConfusionMatrix<T> other) 

创建用另一个ConfusionMatrix的内容初始化的新ConfusionMatrix。

add(T actual, T predicted)

public synchronized void add(T actual, T predicted)

增加实际指定的条目和一个预测条目。

toCSV

public String toCSV() 

将ConfusionMatrix输出为逗号分隔的值,以便于导入电子表格

toHTML

public String toHTML() 

在HTML表中输出混淆矩阵。级联样式表(CSS)可以通过定义empty-space, actual-count-header, predicted-class-header, and count-element 类来控制表的外观。例如

  • 返回 html 字符 

ROCMultiClass

[源码]

ROC(受试者工作特征曲线)用于多类分类器。

ROC曲线是通过将预测看作一组对全部的分类器来生成的,然后计算每个分类器的ROC曲线。在实践中,这意味着对于N类,我们得到N 条 ROC曲线。

ROCMultiClass

public ROCMultiClass(int thresholdSteps) 
  • 参数 thresholdSteps 计算的阈值步骤数。设置为0用于精确的ROC计算

ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts)

 public ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts) 
  • 参数 thresholdSteps 用于ROC计算的阈值步骤数。如果设置为0:使用精确计算
  • 参数 rocRemoveRedundantPts 通常设置为true。如果为true,则从ROC和P R曲线中删除任何冗余点。

eval

public void eval(INDArray labels, INDArray predictions) 

对给定的小批量数据进行评估(收集统计数据)。对于时间序列(3维),使用{-link#evalTimeSeries(INDArray,INDArray)}或{-link#evalTimeSeries(INDArray,INDArray,INDArray)}

  • 参数 labels 标签/真实结果
  • 参数 predictions 预测

getRocCurve

public RocCurve getRocCurve(int classIdx) 

获取指定类的(一对全部)ROC曲线

  • 参数 classIdx 获取ROC曲线的类索引
  • 返回给定类的ROC曲线

getPrecisionRecallCurve

public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx) 

获取指定类的(一对多)精确召回曲线

  • 参数 classIdx 获取P-R曲线的类索引
  • 返回给定类的精确召回曲线

calculateAUC

public double calculateAUC(int classIdx) 

计算ROC曲线下的AUC面积

在内部利用梯形积分

  • 返回 AUC

calculateAUCPR

public double calculateAUCPR(int classIdx) 

计算精度召回曲线下的AUPRC面积

在内部利用梯形积分

  • 返回 AUC

calculateAverageAUC

public double calculateAverageAUC() 

计算所有类的宏平均值(一对全部)AUC

calculateAverageAUCPR

public double calculateAverageAUCPR() 

计算所有类的宏平均(1对全部)AUCPR(精确召回曲线下的区域)

getCountActualPositive

public long getCountActualPositive(int outputNum) 

获取指定类的实际阳性统计(考虑任何掩码)

  • 参数 outputNum 类索引

getCountActualNegative

public long getCountActualNegative(int outputNum) 

获取指定输出/列的实际阴性统计(考虑任何掩码)

  • 参数 outputNum 类索引

merge

public void merge(ROCMultiClass other) 

将此ROCMultiClass实例与另一个实例合并。这个ROCMultiClass实例通过添加来自另一个实例的统计数据而被修改。

  • 参数 other 此结合的ROCMultiClass 实例

ROC

[源码]

ROC (受试者工作特征曲线)用于多类分类器。
ROC  有两种操作模式: (a) Thresholded (内存占用少)
(b) Exact (默认; 使用 numSteps == 0 来设置。可能不会扩展到非常大的数据集)

Thresholded(阈值化)是一种近似方法,(对于大型数据集)可能比exact(精确)模式使用更少的内存。尽管exact(精确)的实现将基于数据集自动计算阈值点,以给出“更平滑”和更精确的ROC曲线(或用于诊断目的的最佳切割点),阈值化使用大小为1.0的固定步骤/阈值步骤,因为这允许容易地实现批处理和分布式评估场景(其中完整数据集不能同时在任意一台机器上的内存中使用)。注意,在某些情况下(例如,非常歪斜的概率预测),阈值方法可能是不准确的,常常低估了真实区域。

假设数据是二分类-nColumns==1(单个二分类输出变量)或nColumns==2(在2个类上的概率分布,其中列1是“阳性”示例的值)

ROC

public ROC(int thresholdSteps) 
  • 能数 thresholdSteps 用于ROC计算的阈值步骤数。如果设置为0:使用精确计算

ROC(int thresholdSteps, boolean rocRemoveRedundantPts)

 public ROC(int thresholdSteps, boolean rocRemoveRedundantPts)
  • 参数 thresholdSteps 用于ROC计算的阈值步骤数。如果设置为0:使用精确计算
  • 参数 rocRemoveRedundantPts 通常设置为true。如果为true,则从ROC和P R曲线中删除任何冗余点。

eval

public void eval(INDArray labels, INDArray predictions) 

对给定的小批量数据进行评估(收集统计数据)。对于时间序列(3维),使用{-link#evalTimeSeries(INDArray,INDArray)}或{-link#evalTimeSeries(INDArray,INDArray,INDArray)}

  • 参数 labels 标签/真实结果
  • 参数 predictions 预测

 

getPrecisionRecallCurve

public PrecisionRecallCurve getPrecisionRecallCurve() 

以数组的形式获得精确召回曲线. return[0] = threshold array
return[1] = precision array
return[2] = recall array

  • return

getRocCurve

public RocCurve getRocCurve() 

获取ROC曲线,作为一组(阈值,假阳性,真阳性)点。

  • 返回 ROC 曲线

calculateAUC

public double calculateAUC() 

计算AUROC-ROC曲线下的面积

在内部利用梯形积分

  • 返回 AUC

calculateAUCPR

public double calculateAUCPR() 

计算精度/召回曲线下的面积

  • return

merge

public void merge(ROC other) 

将ROC实例与另一个实例合并。通过添加来自另一实例的STATS,修改此ROC实例。

  • param other ROC instance to combine with this one

IEvaluation

[源码]

用于评估神经网络的通用接口.方法由以下实现共享


EvaluationCalibration

[源码]

EvaluationCalibration是一个评估类,用于分析分类器的校准。它提供了许多用于此目的的工具:

  • 每个类别的标签数量和预测的统计
  • 可靠性图(或可靠性曲线)
  • 残差图(直方图)
  • 概率直方图,包括每个类的概率
    参考文献:
  • 可靠性图:参见例如Niculescu-Mizil和Caruana 2005,使用监督学习预测良好概率
  • 残差图:参见Wallace和Dahabreh 2012,类概率估计对于不平衡数据是不可靠的(以及如何修正它们)

EvaluationCalibration

public EvaluationCalibration() 

eval

public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) 

使用默认数量的容器创建一个评估校准实例

  • 参数 reliabilityDiagNumBins 可靠性图表的容器数(通常为10)
  • 参数 histogramNumBins 直方图的容器数

getReliabilityDiagram

public ReliabilityDiagram getReliabilityDiagram(int classIdx) 

获取指定类的可靠性图

  • 参数 classIdx 获得可靠性图的类的索引

getResidualPlotAllClasses

public Histogram getResidualPlotAllClasses() 
  • 返回每个类的观察标签的数量。对于N个类,被返回数组的长度为N,out[i]是类i的标签数

getResidualPlot

public Histogram getResidualPlot(int labelClassIdx) 

只为指定的类的例子获取残差图。残差图被定义为直方图。

|label_i-prob(class_i|input)|用于所有和示例;对于这个特定的方法,仅包括i==labelClassIdx的预测。

在一般情况下,相比大残差,小残差表示一个更优越的分类器。

  • 参数 labelClassIdx获取残差图的类的索引
  • 返回残差图(直方图)-所有预测/类

getProbabilityHistogramAllClasses

public Histogram getProbabilityHistogramAllClasses() 

返回所有预测/类的概率直方图。

  • 返回概率直方图

getProbabilityHistogram

public Histogram getProbabilityHistogram(int labelClassIdx) 

返回指定的标签类索引的概率直方图。也就是说,对于标签类索引i,返回P(class_i|input)的直方图,仅用于标记为类i的那些示例。

  • 参数  labelClassIdx 标签类的索引以获得直方图
  • 返回概率直方图

EvaluationUtils

[源码]

评估方法工具

precision

public static double precision(long tpCount, long fpCount, double edgeCase) 

计算真阳性和假阳性统计的精确度

  • 参数 tpCount 真阳性统计
  • 参数 fpCount 假阳性统计
  • 参数 edgeCase 边缘案例值避免使用0/0
  • 返回 精确度

recall

public static double recall(long tpCount, long fnCount, double edgeCase) 

计算真阳性和假阴性统计的召回率

  • 参数 tpCount 真阳性统计
  • 参数 fnCount 假阴性统计
  • 参数 edgeCase 边缘案例值避免使用0/0
  • 返回召回率

falsePositiveRate

public static double falsePositiveRate(long fpCount, long tnCount, double edgeCase) 

根据假阳性统计和真阴性统计计算假阳性率

  • 参数 fpCount 假阳性统计
  • 参数 tnCount 真阴性统计
  • 参数 edgeCase 边缘案例值避免使用0/0
  • 返回假阳性率

falseNegativeRate

public static double falseNegativeRate(long fnCount, long tpCount, double edgeCase) 

根据假阴性统计和真阳性统计计算假阴性率

  • 参数 fnCount 假阴性统计
  • 参数 tpCount 真阳性统计
  • 参数 edgeCase 边缘案例值避免使用0/0
  • 返回假阴性率

fBeta

public static double fBeta(double beta, long tp, long fp, long fn) 

从统计计算F-beta值

  • 参数 beta 使用的Beta值
  • 参数 tp 真阳性统计
  • 参数 fp 假阳性统计
  • 参数 fn 假阴性统计
  • 返回F beta

fBeta

public static double fBeta(double beta, double precision, double recall) 

从精确率与召回计算F-beta值

  • 参数 beta 使用的Beta值
  • 参数 precision 精确率
  • 参数 recall 召回率
  • 返回 F-beta 值

gMeasure

public static double gMeasure(double precision, double recall) 

从精确率与召回计算 G-measure值

  • 参数 precision 精确率
  • 参数 recall 召回率
  • 返回 G-measure

matthewsCorrelation

public static double matthewsCorrelation(long tp, long fp, long fn, long tn) 

从统计计算二分类马休斯相关系数

  • 参数 tp 真阳统计
  • 参数 fp 假阳统计
  • 参数 fn 假阴统计
  • 参数 tn 真阴统计
  • 返回马休斯相关系数

EvaluationBinary

[源码]

EvaluationBinary: 用于评估具有二分类输出的网络。对每个输出计算典型的分类度量,如准确度、精确度、召回率、F1评分等。

注意,EvaluationBinary支持每个示例和每个输出掩码。
默认情况下,EvaluationBinary使用0.5的决策阈值,但是可以对每个输出设置决策阈值

EvaluationBinary

public EvaluationBinary(INDArray decisionThreshold) 

用可选的决策阈值数组创建EvaulationBinary实例。

  • 参数 decisionThreshold 每个输出的判定阈值,可以是空的。应该是长度等于输出数量的行向量,其值在0到1的范围内。0.5值的数组等效于默认值(没有手动指定的决策阈值)。

eval

public void eval(INDArray labels, INDArray networkPredictions) 

当rocBinarySteps参数非空时,此构造器允许除了标准评估度量之外还计算ROC。详情请参阅 {- link ROCBinary}

  • 参数 size 输出数量
  • 参数 rocBinarySteps 用于  {- link ROCBinary#ROCBinary(int)}的构造器

numLabels

public int numLabels() 

如果已知,返回标签的数目(即,预测/标签阵列的大小)。否则返回-1

setLabelNames

public void setLabelNames(List<String> labels) 

设置标签名称,用于通过{- link #stats()}打印

totalCount

public int totalCount(int outputNum) 

获取指定列的值总数,考虑任何掩码

truePositives

public int truePositives(int outputNum) 

获取指定输出的真阳统计

trueNegatives

public int trueNegatives(int outputNum) 

获取指定输出的真阴统计

falsePositives

public int falsePositives(int outputNum) 

获取指定输出的假阳统计

falseNegatives

public int falseNegatives(int outputNum) 

获取指定输出的假阴统计

accuracy

public double accuracy(int outputNum) 

获取指定输出的准确率

precision

public double precision(int outputNum) 

获取指定输出的精度(tp / (tp + fp))

recall

public double recall(int outputNum) 

获取指定输出的召回率(tp/(tp+fn))

fBeta

public double fBeta(double beta, int outputNum) 

获取指定输出的 F-beta 值

  • 参数 beta 使用的beta值
  • 参数 outputNum 输出数
  • 返回指定输出的 F-beta

f1

public double f1(int outputNum) 

获取指定输出的F1 得分

matthewsCorrelation

public double matthewsCorrelation(int outputNum) 

计算指定输出的马休斯相关系数

  • 参数 outputNum 输出数字
  • 返回马休斯相关系数

gMeasure

public double gMeasure(int output) 

计算指定输出的 G-measure

  • 参数 output 指定输出
  • 返回指定输出的 G-measure
public double falsePositiveRate(int classLabel) 

返回给定标签的假阳性率

  • 参数 classLabel 标签
  • 返回double类型的假阳性率 

falsePositiveRate

public double falsePositiveRate(int classLabel, double edgeCase) 

返回给定标签的假阳性率

  • 参数 classLabel 标签
  • 参数edgeCase 0/0情况下的输出
  • 返回double类型的假阳性率 

 

falseNegativeRate

public double falseNegativeRate(Integer classLabel) 

返回给定标签的假阴性率

  • 参数 classLabel 标签
  • 返回double类型的假阴性率 

 

falseNegativeRate

public double falseNegativeRate(Integer classLabel, double edgeCase) 

返回给定标签的假阴性率

  • 参数 classLabel 标签
  • 参数edgeCase 0/0情况下的输出
  • 返回double类型的假阴性率 

getROCBinary

public ROCBinary getROCBinary() 

返回ROCBinary实例,如果存在

stats

public String stats() 

使用默认精度获取EvaluationBinary类的String表示

stats

public String stats(int printPrecision) 

使用默认精度获取EvaluationBinary类的String表示

  • 参数 printPrecision 精度(小数位数)用于准确率、F1等。

RegressionEvaluation

[源码]

评估回归算法的评估方法。
为每个列提供以下度量:

RegressionEvaluation

public RegressionEvaluation() 
  • 如果度量应该最小化,则返回True;如果度量应该最大化,则返回false。例如,MSE为0是最好的,但是R^ 2为1是最好的。

correlationR2

public double correlationR2(int column) 

相关评分的遗留方法。

  • 参数 column要评估的列
  • 返回给定列的皮尔森相关系数 
  • 查看{- link #pearsonCorrelation(int)}
  • 弃用,用{- link #pearsonCorrelation(int)} 替代。对于R2得分的使用 {- link #rSquared(int)}.

pearsonCorrelation

public double pearsonCorrelation(int column) 

样本的皮尔森相关系数

  • 参数 column 要评估的列
  • 返回具有索引{代码列}的列皮尔森相关系数
  • 查看 Wikipedia

rSquared

public double rSquared(int column) 

决定系数 (R^2 Score)

  • 参数 column 要评估的列
  • 返回具有索引{代码列}的列的R^ 2分数
  • 查看 Wikipedia

averageMeanSquaredError

public double averageMeanSquaredError() 

所有列的平均MSE

  • return

averageMeanAbsoluteError

public double averageMeanAbsoluteError() 

所有列的平均MAE

  • return

averagerootMeanSquaredError

public double averagerootMeanSquaredError() 

所有列的平均RMSE

  • return

averagerelativeSquaredError

public double averagerelativeSquaredError() 

所有列的平均RSE

  • return

averagecorrelationR2

public double averagecorrelationR2() 

用于所有列的相关平均值的遗留方法。

  • 在所有列上平均的返回皮尔森相关
  • 查看 {- link #averagePearsonCorrelation()}
  • 弃用,用{- link #averagePearsonCorrelation()} 替代。对于R2得分的使用 {- link #averageRSquared()}.

averagePearsonCorrelation

public double averagePearsonCorrelation() 

所有列的平均皮尔森相关系数

  • 所有列的平均皮尔森相关系数

averageRSquared

public double averageRSquared() 

所有列的平均R2

  • 返回所有列的平均R2

EvaluationAveraging

[源码]

当应用于多类分类问题时,二分类评估度量的平均化方法。宏平均:每个类的权重相等
微平均:每个例子权重相等
一般来说,对于不平衡数据集来说,宏平均是首选的。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值