- Interface RecordMetaData
RecordMetaData includes details on the record itself - for example, the source file or line number.
It is used in conjunction with org.datavec.api.records.reader.RecordReaderMeta.
There are two primary uses:
(a) Tracking where a record came from, for example when debugging
(b) Loading the raw data again later from the record reader
No. | Method | Returns | Notes |
1 | getLocation | Human-readable description of where the record came from | |
2 | getURI | URI of the source file | |
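The two uses listed above can be illustrated with a minimal, self-contained sketch. It does not use DataVec: `LineMeta` and `InMemoryLineReader` are hypothetical stand-ins for a RecordMetaData implementation and a metadata-aware record reader, and it assumes each record is one line of a text source.

```java
import java.net.URI;
import java.util.Arrays;
import java.util.List;

class RecordMetaSketch {
    // Hypothetical stand-in for a RecordMetaData implementation (not the DataVec class).
    static final class LineMeta {
        private final URI uri;
        private final int lineNumber;

        LineMeta(URI uri, int lineNumber) {
            this.uri = uri;
            this.lineNumber = lineNumber;
        }

        // Human-readable description of where the record came from.
        String getLocation() {
            return "Line " + lineNumber + " of " + uri;
        }

        URI getURI() {
            return uri;
        }

        int getLineNumber() {
            return lineNumber;
        }
    }

    // Hypothetical reader that keeps its lines in memory and can reload one from metadata.
    static final class InMemoryLineReader {
        private final URI uri;
        private final List<String> lines;

        InMemoryLineReader(URI uri, List<String> lines) {
            this.uri = uri;
            this.lines = lines;
        }

        // Metadata for the record at the given zero-based line index.
        LineMeta metaFor(int lineNumber) {
            return new LineMeta(uri, lineNumber);
        }

        // Use (b): load the raw data again later, using only the metadata.
        String loadFromMetaData(LineMeta meta) {
            return lines.get(meta.getLineNumber());
        }
    }

    public static void main(String[] args) {
        URI source = URI.create("file:///data/iris.csv"); // illustrative path only
        InMemoryLineReader reader = new InMemoryLineReader(source,
                Arrays.asList("5.1,3.5,1.4,0.2,0", "7.0,3.2,4.7,1.4,1"));
        LineMeta meta = reader.metaFor(1);
        System.out.println(meta.getLocation());            // use (a): where did this record come from?
        System.out.println(reader.loadFromMetaData(meta)); // use (b): reload the raw record
    }
}
```

In this sketch the metadata object carries just enough information (source URI and line index) to find the record again, which is the property both uses rely on.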
Interface DataSet
No. | Method | Returns | Notes |
1 | getExampleMetaData() | Get the example metadata, or null if no metadata has been set | |
2 | getExampleMetaData(java.lang.Class<T> metaDataType) | Get the example metadata, or null if no metadata has been set. Note: this method results in an unchecked cast, so care should be taken when using it! | |
3 | getFeatures() | Returns the features array for the DataSet | |
4 | getLabels() | Returns the labels for the DataSet | |
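The warning about the unchecked cast in getExampleMetaData(Class<T>) is easier to see with a small example. The sketch below is not the ND4J implementation: `UncheckedCastSketch` is a hypothetical holder that only mirrors the shape of that method, under the assumption that metadata is stored as a list of Serializable objects. It shows that asking for the wrong type does not fail at the call, only later when an element is read.

```java
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

class UncheckedCastSketch {
    // Hypothetical storage: metadata kept as plain Serializable objects.
    private final List<Serializable> meta = new ArrayList<>();

    void add(Serializable m) {
        meta.add(m);
    }

    // Mirrors the shape of getExampleMetaData(Class<T>). The cast is unchecked:
    // because of type erasure, the element type is NOT verified here.
    @SuppressWarnings("unchecked")
    <T extends Serializable> List<T> getExampleMetaData(Class<T> metaDataType) {
        return (List<T>) (List<?>) meta;
    }

    // Returns true if requesting the wrong type only fails when an element is read.
    static boolean wrongTypeFailsOnRead() {
        UncheckedCastSketch holder = new UncheckedCastSketch();
        holder.add("iris.csv, line 3"); // the stored metadata is actually a String
        List<Integer> wrong = holder.getExampleMetaData(Integer.class); // no error yet
        try {
            Integer first = wrong.get(0); // the ClassCastException is thrown here
            return false;
        } catch (ClassCastException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println("Wrong type fails on read: " + wrongTypeFailsOnRead());
    }
}
```

The practical consequence is that the Class argument should match the metadata type that was actually attached when the data was loaded.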
Class Evaluation
Evaluation metrics:
- precision, recall, f1, fBeta, accuracy, Matthews correlation coefficient, gMeasure
- Top N accuracy (if using constructor Evaluation(List, int))
- Custom binary evaluation decision threshold (use constructor Evaluation(double); default if not set is argmax / 0.5)
- Custom cost array, using Evaluation(INDArray) or Evaluation(List, INDArray) for multi-class
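Of the metrics listed above, top N accuracy is perhaps the least self-explanatory: a prediction counts as correct when the actual class is among the N most probable classes. The following is a minimal sketch of that idea in plain Java, not the library's code. The class and method names are hypothetical, and the handling of ties (a class with equal probability does not push the actual class out of the top N) is an assumption.

```java
class TopNAccuracySketch {
    // True if the actual class is among the n highest-probability classes.
    // Assumption: only strictly higher probabilities count against the actual class.
    static boolean inTopN(double[] probabilities, int actualClass, int n) {
        int higher = 0;
        for (double p : probabilities) {
            if (p > probabilities[actualClass]) {
                higher++;
            }
        }
        return higher < n;
    }

    // Fraction of examples whose actual class is within the top n predictions.
    static double topNAccuracy(double[][] probabilities, int[] actual, int n) {
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (inTopN(probabilities[i], actual[i], n)) {
                correct++;
            }
        }
        return (double) correct / actual.length;
    }

    public static void main(String[] args) {
        double[][] probs = {
            {0.6, 0.3, 0.1},
            {0.2, 0.5, 0.3},
            {0.1, 0.2, 0.7}
        };
        int[] actual = {0, 2, 1};
        System.out.println("Top 1 accuracy: " + topNAccuracy(probs, actual, 1));
        System.out.println("Top 2 accuracy: " + topNAccuracy(probs, actual, 2));
    }
}
```

With N = 1 this reduces to ordinary accuracy, since only the argmax class can then be counted as correct.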
Note: Care should be taken when using the Evaluation class for binary classification metrics such as F1, precision, recall, etc. There are a number of cases to consider:
1. For binary classification (1 or 2 network outputs)
a) Default behaviour: class 1 is assumed as the positive class. Consequently, no-arg methods such as f1(), precision(), recall() etc. will report the binary metric for class 1 only.
b) To set class 0 as the positive class instead of class 1 (the default), use Evaluation(int, Integer) or Evaluation(double, Integer) or setBinaryPositiveClass(Integer). Then f1(), precision(), recall() etc. will report the binary metric for class 0 only.
c) To use macro-averaged metrics over both classes for binary classification (uncommon and usually not advisable), specify 'null' as the argument (instead of 0 or 1) as per (b) above.
2. For multi-class classification, binary metric methods such as f1(), precision(), recall() will report macro-average (of the one-vs-all) binary metrics. Note that you can specify micro vs. macro averaging using f1(EvaluationAveraging) and similar methods.
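The three binary cases (positive class 1, positive class 0, or null for the macro average) can be sketched without the library. The code below is an illustration only: the class name, the 2x2 confusion matrix values, and the choice to return 0.0 when nothing was predicted as a class are all assumptions, not the behaviour of Evaluation itself.

```java
class BinaryPositiveClassSketch {
    // Precision for one class of a two-class problem.
    // confusion[actual][predicted] holds the counts.
    // Assumption: returns 0.0 if no example was predicted as this class.
    static double precisionForClass(int[][] confusion, int cls) {
        int truePositives = confusion[cls][cls];
        int predictedAsCls = confusion[0][cls] + confusion[1][cls];
        return predictedAsCls == 0 ? 0.0 : (double) truePositives / predictedAsCls;
    }

    // positiveClass = 1 (the default case), 0, or null for the macro average of both classes.
    static double precision(int[][] confusion, Integer positiveClass) {
        if (positiveClass == null) {
            return (precisionForClass(confusion, 0) + precisionForClass(confusion, 1)) / 2.0;
        }
        return precisionForClass(confusion, positiveClass);
    }

    public static void main(String[] args) {
        int[][] confusion = {
            {40, 10}, // actual class 0: 40 predicted as 0, 10 predicted as 1
            {5, 45}   // actual class 1: 5 predicted as 0, 45 predicted as 1
        };
        System.out.println("Positive class 1: " + precision(confusion, 1));
        System.out.println("Positive class 0: " + precision(confusion, 0));
        System.out.println("Macro average:    " + precision(confusion, null));
    }
}
```

The same confusion matrix gives three different precision values, which is why it matters which class is treated as positive when reporting a single number.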
Note that setting a custom binary decision threshold is only possible for the binary case (1 or 2 outputs) and cannot be used if the number of classes exceeds 2. Predictions with probability > threshold are considered to be class 1, and are considered class 0 otherwise.
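As a rough illustration of the threshold rule just described, the sketch below applies it to the one-output and two-output cases. The class and method names are hypothetical, and the assumption that index 1 of a two-element output holds the class 1 probability is made for this example only.

```java
class BinaryThresholdSketch {
    // Single-output case: p1 is the predicted probability of class 1.
    // Class 1 is chosen only if the probability is strictly greater than the threshold.
    static int predict(double p1, double threshold) {
        return p1 > threshold ? 1 : 0;
    }

    // Two-output case: assumes probabilities[1] is the probability of class 1.
    static int predict(double[] probabilities, double threshold) {
        return predict(probabilities[1], threshold);
    }

    public static void main(String[] args) {
        System.out.println(predict(0.6, 0.5));                       // default-like threshold
        System.out.println(predict(0.6, 0.7));                       // stricter threshold
        System.out.println(predict(new double[]{0.35, 0.65}, 0.7));  // two-output case
    }
}
```

Raising the threshold above 0.5 makes class 1 predictions rarer, which typically trades recall for precision on class 1.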
Cost arrays (a row vector, of size equal to the number of outputs) modify the evaluation process: instead of simply doing predictedClass = argMax(probabilities), we do predictedClass = argMax(cost * probabilities). Consequently, an array of all 1s (or, indeed, any array of equal values) will result in the same performance as no cost array; non-equal values will bias the predictions for or against certain classes.
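The effect of a cost array can be checked with a few lines of plain Java. This is a sketch of the rule predictedClass = argMax(cost * probabilities) using ordinary arrays rather than INDArray; the class name and the example values are made up for illustration.

```java
class CostArraySketch {
    // Index of the largest value (the first one wins on ties).
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // predictedClass = argMax(cost * probabilities), with element-wise multiplication.
    static int predict(double[] probabilities, double[] cost) {
        double[] weighted = new double[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            weighted[i] = cost[i] * probabilities[i];
        }
        return argMax(weighted);
    }

    public static void main(String[] args) {
        double[] probs = {0.5, 0.3, 0.2};
        System.out.println(predict(probs, new double[]{1, 1, 1})); // same as plain argMax
        System.out.println(predict(probs, new double[]{3, 3, 3})); // equal costs change nothing
        System.out.println(predict(probs, new double[]{1, 2, 1})); // biased towards class 1
    }
}
```

With costs {1, 2, 1} the weighted scores become 0.5, 0.6 and 0.2, so class 1 is predicted even though class 0 has the highest raw probability.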
// Assumes model, testData, testMetaData and log are defined earlier in the program
Evaluation eval = new Evaluation(3); // 3 classes
INDArray output = model.output(testData.getFeatures()); // Predicted class probabilities for the test features
eval.eval(testData.getLabels(), output, testMetaData); // Note we are passing in the test set metadata here
log.info("\n{}\n", output);
log.info("Evaluation statistics: " + eval.stats()); // Report the classification statistics as a String
List<Prediction> predictionErrors = eval.getPredictionErrors();
System.out.println("\n\n+++++ Prediction Errors +++++");
for (Prediction p : predictionErrors) {
    System.out.println("Predicted class: " + p.getPredictedClass() + ", Actual class: " + p.getActualClass()
            + "\t" + p.getRecordMetaData(RecordMetaData.class).getLocation());
}
List<Prediction> list1 = eval.getPredictions(1, 2); // Predictions: actual class 1, predicted class 2
List<Prediction> list2 = eval.getPredictionByPredictedClass(2); // All predictions for predicted class 2
List<Prediction> list3 = eval.getPredictionsByActualClass(2); // All predictions for actual class 2
No. | Method | Returns | Notes |
1 | Evaluation(int numClasses) | Constructor. numClasses is the number of classes to account for in the evaluation | |
2 | eval(INDArray realOutcomes, INDArray guesses, java.util.List<? extends java.io.Serializable> recordMetaData) | Evaluate the network, with optional metadata | |
3 | getPredictionErrors() | Get a list of prediction errors, on a per-record basis | |
4 | getPredictions(int actualClass, int predictedClass) | Get a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair) | |
5 | getPredictionByPredictedClass(int predictedClass) | Get a list of predictions, for all data with the specified predicted class, regardless of the actual data class. Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used. | |
6 | getPredictionsByActualClass(int actualClass) | Get a list of predictions, for all data with the specified actual class, regardless of the predicted class. Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used. | |
7 | stats() | Report the classification statistics as a String | |
Example output of the stats() method
========================Evaluation Metrics========================
# of classes: 3
Accuracy: 0.6038
Precision: 0.7083
Recall: 0.6845
F1 Score: 0.5852
Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)
=========================Confusion Matrix (also known as an error matrix)=========================
0 1 2
----------
17 0 0 | 0 = 0
0 12 1 | 1 = 1
0 20 3 | 2 = 2
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
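The summary figures in this report can be recomputed by hand from the confusion matrix, which is a useful sanity check when reading stats() output. The sketch below does this in plain Java. The class and method names are hypothetical, and returning 0.0 for a zero denominator is an assumption that does not come into play for this matrix.

```java
import java.util.Locale;

class ConfusionMatrixMetricsSketch {
    // Confusion matrix from the report: rows are actual classes, columns are predicted classes.
    static final int[][] MATRIX = {
        {17, 0, 0},
        {0, 12, 1},
        {0, 20, 3}
    };

    // Assumption: a zero denominator yields 0.0 (the library may treat this case differently).
    static double ratio(double numerator, double denominator) {
        return denominator == 0 ? 0.0 : numerator / denominator;
    }

    static int rowSum(int[][] m, int row) {
        int sum = 0;
        for (int v : m[row]) sum += v;
        return sum;
    }

    static int colSum(int[][] m, int col) {
        int sum = 0;
        for (int[] row : m) sum += row[col];
        return sum;
    }

    // Correct predictions (the diagonal) divided by all predictions.
    static double accuracy(int[][] m) {
        int correct = 0, total = 0;
        for (int c = 0; c < m.length; c++) {
            correct += m[c][c];
            total += rowSum(m, c);
        }
        return ratio(correct, total);
    }

    // Macro average: per-class precision (TP / column sum), equally weighted.
    static double macroPrecision(int[][] m) {
        double sum = 0;
        for (int c = 0; c < m.length; c++) sum += ratio(m[c][c], colSum(m, c));
        return sum / m.length;
    }

    // Macro average: per-class recall (TP / row sum), equally weighted.
    static double macroRecall(int[][] m) {
        double sum = 0;
        for (int c = 0; c < m.length; c++) sum += ratio(m[c][c], rowSum(m, c));
        return sum / m.length;
    }

    // Macro average of per-class F1 = 2TP / (2TP + FP + FN).
    static double macroF1(int[][] m) {
        double sum = 0;
        for (int c = 0; c < m.length; c++) {
            int tp = m[c][c];
            int fp = colSum(m, c) - tp;
            int fn = rowSum(m, c) - tp;
            sum += ratio(2.0 * tp, 2.0 * tp + fp + fn);
        }
        return sum / m.length;
    }

    public static void main(String[] args) {
        System.out.println(String.format(Locale.ROOT, "Accuracy:  %.4f", accuracy(MATRIX)));
        System.out.println(String.format(Locale.ROOT, "Precision: %.4f", macroPrecision(MATRIX)));
        System.out.println(String.format(Locale.ROOT, "Recall:    %.4f", macroRecall(MATRIX)));
        System.out.println(String.format(Locale.ROOT, "F1 Score:  %.4f", macroF1(MATRIX)));
    }
}
```

Rounded to four decimals, these come out as 0.6038, 0.7083, 0.6845 and 0.5852, matching the report. The low recall is driven by class 2, where 20 of 23 actual examples were predicted as class 1.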