DL4J: Notes on the RecordMetaData, DataSet and Evaluation Classes

  • Interface RecordMetaData

RecordMetaData includes details on the record itself, for example the source file or line number. It is used in conjunction with org.datavec.api.records.reader.RecordReaderMeta.
There are two primary uses:
(a) Tracking where a record has come from, for example for debugging
(b) Loading the raw data again later, from the record reader

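No. | Method        | Description
1   | getLocation() | Returns human-readable location information for the record
2   | getURI()      | Returns the URI of the source file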
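To make the two uses concrete, here is a minimal, self-contained Java sketch that does not depend on DataVec. The classes LineMetaData and InMemoryLineReader, the method names metaDataForAll and loadFromMetaData, and the URI file:///tmp/data.csv are hypothetical stand-ins chosen for illustration; DataVec's real RecordMetaData implementations and record readers differ in detail. The sketch only shows the idea: each record carries its source URI and line number, which serves use (a), and that metadata alone is enough to fetch the raw record again, which serves use (b).

```java
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for RecordMetaData: where one record came from.
final class LineMetaData {
    private final URI uri;        // source of the record
    private final int lineNumber; // 0-based line index within that source

    LineMetaData(URI uri, int lineNumber) {
        this.uri = uri;
        this.lineNumber = lineNumber;
    }

    URI getURI() { return uri; }

    int getLineNumber() { return lineNumber; }

    // Human-readable location, analogous to getLocation()
    String getLocation() { return uri + " line " + lineNumber; }
}

// Hypothetical reader over in-memory lines that hands out metadata per record.
final class InMemoryLineReader {
    private final URI source;
    private final List<String> lines;

    InMemoryLineReader(URI source, List<String> lines) {
        this.source = source;
        this.lines = new ArrayList<>(lines);
    }

    // Use (a): one metadata object per record, for tracking/debugging.
    List<LineMetaData> metaDataForAll() {
        List<LineMetaData> out = new ArrayList<>();
        for (int i = 0; i < lines.size(); i++) {
            out.add(new LineMetaData(source, i));
        }
        return out;
    }

    // Use (b): reload the raw record later, given only its metadata.
    String loadFromMetaData(LineMetaData meta) {
        if (!source.equals(meta.getURI())) {
            throw new IllegalArgumentException("Metadata refers to another source: " + meta.getURI());
        }
        return lines.get(meta.getLineNumber());
    }

    public static void main(String[] args) {
        InMemoryLineReader reader = new InMemoryLineReader(
                URI.create("file:///tmp/data.csv"), List.of("5.1,3.5,0", "6.2,2.9,1"));
        for (LineMetaData meta : reader.metaDataForAll()) {
            System.out.println(meta.getLocation() + " -> " + reader.loadFromMetaData(meta));
        }
    }
}
```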

Interface DataSet

No. | Method                                              | Description
1   | getExampleMetaData()                                | Get the example metadata, or null if no metadata has been set
2   | getExampleMetaData(java.lang.Class<T> metaDataType) | Get the example metadata, or null if no metadata has been set. Note: this method results in an unchecked cast; care should be taken when using this!
3   | getFeatures()                                       | Returns the features array for the DataSet
4   | getLabels()                                         | Returns the labels for the dataset
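The warning on the typed getExampleMetaData(Class<T>) can be illustrated without ND4J. In this sketch, MiniDataSet and getExampleMetaDataChecked are hypothetical (plain double[][] arrays stand in for INDArray); the point is that an unchecked cast of a List does not inspect its elements, so asking for the wrong metadata type only fails later, when an element is read. The checked variant is one possible defensive alternative, not part of DL4J.

```java
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified stand-in for a DataSet: features, labels and
// optional per-example metadata (null when none has been set).
final class MiniDataSet {
    private final double[][] features; // one row per example
    private final double[][] labels;   // one (e.g. one-hot) row per example
    private List<? extends Serializable> exampleMetaData; // may be null

    MiniDataSet(double[][] features, double[][] labels) {
        this.features = features;
        this.labels = labels;
    }

    double[][] getFeatures() { return features; }

    double[][] getLabels() { return labels; }

    void setExampleMetaData(List<? extends Serializable> exampleMetaData) {
        this.exampleMetaData = exampleMetaData;
    }

    // Untyped accessor: null if no metadata has been set.
    List<? extends Serializable> getExampleMetaData() { return exampleMetaData; }

    // Typed accessor with an UNCHECKED cast: metaDataType only fixes T and is never
    // checked, so a wrong type fails later, when an element is read.
    @SuppressWarnings("unchecked")
    <T extends Serializable> List<T> getExampleMetaData(Class<T> metaDataType) {
        return (List<T>) (List<?>) exampleMetaData;
    }

    // Checked alternative: casts every element up front, failing immediately.
    <T extends Serializable> List<T> getExampleMetaDataChecked(Class<T> metaDataType) {
        if (exampleMetaData == null) {
            return null;
        }
        List<T> out = new ArrayList<>();
        for (Serializable item : exampleMetaData) {
            out.add(metaDataType.cast(item)); // ClassCastException on mismatch
        }
        return out;
    }

    public static void main(String[] args) {
        MiniDataSet ds = new MiniDataSet(new double[][] {{5.1, 3.5}}, new double[][] {{1.0, 0.0}});
        ds.setExampleMetaData(List.of("row-0"));
        System.out.println(ds.getExampleMetaDataChecked(String.class));
    }
}
```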

Class Evaluation

Evaluation metrics:
- precision, recall, f1, fBeta, accuracy, Matthews correlation coefficient, gMeasure
- Top N accuracy (if using constructor Evaluation(List, int))
- Custom binary evaluation decision threshold (use constructor Evaluation(double); default if not set is argmax / 0.5)
- Custom cost array, using Evaluation(INDArray) or Evaluation(List, INDArray) for multi-class 
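For reference, the binary metrics listed above can be written in terms of confusion-matrix counts (TP, FP, FN, TN). The sketch below uses the standard textbook formulas; the class name BinaryMetrics is hypothetical, returning 0.0 on a zero denominator is an assumption of this sketch (DL4J's handling of such edge cases may differ), and the G-measure is taken here as the geometric mean of precision and recall.

```java
// Standard binary-classification formulas from confusion-matrix counts.
// Hypothetical helper, not DL4J source code.
final class BinaryMetrics {
    private BinaryMetrics() {}

    static double precision(long tp, long fp) {
        return tp + fp == 0 ? 0.0 : (double) tp / (tp + fp);
    }

    static double recall(long tp, long fn) {
        return tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);
    }

    // F-beta = (1 + beta^2) * P * R / (beta^2 * P + R); beta = 1 gives F1.
    static double fBeta(double beta, double precision, double recall) {
        double b2 = beta * beta;
        double denom = b2 * precision + recall;
        return denom == 0.0 ? 0.0 : (1 + b2) * precision * recall / denom;
    }

    static double f1(double precision, double recall) {
        return fBeta(1.0, precision, recall);
    }

    static double accuracy(long tp, long fp, long fn, long tn) {
        long total = tp + fp + fn + tn;
        return total == 0 ? 0.0 : (double) (tp + tn) / total;
    }

    // Matthews correlation coefficient.
    static double matthews(long tp, long fp, long fn, long tn) {
        double denom = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return denom == 0.0 ? 0.0 : ((double) tp * tn - (double) fp * fn) / denom;
    }

    // G-measure as the geometric mean of precision and recall.
    static double gMeasure(double precision, double recall) {
        return Math.sqrt(precision * recall);
    }

    public static void main(String[] args) {
        double p = precision(8, 2);
        double r = recall(8, 4);
        System.out.println("P=" + p + " R=" + r + " F1=" + f1(p, r) + " MCC=" + matthews(8, 2, 4, 6));
    }
}
```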

Note: Care should be taken when using the Evaluation class for binary classification metrics such as F1, precision, recall, etc. There are a number of cases to consider:
1. For binary classification (1 or 2 network outputs)
a) Default behaviour: class 1 is assumed to be the positive class. Consequently, no-arg methods such as f1(), precision(), recall() etc. will report the binary metric for class 1 only.
b) To set class 0 as the positive class instead of class 1 (the default), use Evaluation(int, Integer) or Evaluation(double, Integer) or setBinaryPositiveClass(Integer). Then f1(), precision(), recall() etc. will report the binary metric for class 0 only.
c) To use macro-averaged metrics over both classes for binary classification (uncommon and usually not advisable), specify null as the argument (instead of 0 or 1) as per (b) above.
2. For multi-class classification, binary metric methods such as f1(), precision(), recall() will report macro-averaged (one-vs-all) binary metrics. Note that you can specify micro vs. macro averaging using f1(EvaluationAveraging) and similar methods.
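Cases 1(a) to 1(c) differ only in which class a "binary" metric is reported for. The following sketch, with a hypothetical helper class PositiveClassSelection and a 2x2 confusion matrix indexed as cm[actual][predicted], shows precision and recall for positive class 1, positive class 0, and the macro-average over both (the null case). It mirrors the behaviour described above but is not DL4J's implementation; returning 0.0 on an empty row or column is an assumption.

```java
// Hypothetical helper: which class a "binary" metric is reported for.
final class PositiveClassSelection {
    private PositiveClassSelection() {}

    // cm[actual][predicted] for a 2-class problem.
    static double precisionForClass(int[][] cm, int c) {
        int predictedAsC = cm[0][c] + cm[1][c];
        return predictedAsC == 0 ? 0.0 : (double) cm[c][c] / predictedAsC;
    }

    static double recallForClass(int[][] cm, int c) {
        int actualC = cm[c][0] + cm[c][1];
        return actualC == 0 ? 0.0 : (double) cm[c][c] / actualC;
    }

    // positiveClass: 1 (default), 0, or null for the macro-average over both classes.
    static double binaryPrecision(int[][] cm, Integer positiveClass) {
        if (positiveClass == null) {
            return (precisionForClass(cm, 0) + precisionForClass(cm, 1)) / 2.0;
        }
        return precisionForClass(cm, positiveClass);
    }

    static double binaryRecall(int[][] cm, Integer positiveClass) {
        if (positiveClass == null) {
            return (recallForClass(cm, 0) + recallForClass(cm, 1)) / 2.0;
        }
        return recallForClass(cm, positiveClass);
    }

    public static void main(String[] args) {
        int[][] cm = {{6, 2}, {4, 8}}; // rows: actual 0/1, columns: predicted 0/1
        System.out.println("class 1: " + binaryPrecision(cm, 1)
                + ", class 0: " + binaryPrecision(cm, 0)
                + ", macro: " + binaryPrecision(cm, null));
    }
}
```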

Note that setting a custom binary decision threshold is only possible for the binary case (1 or 2 outputs) and cannot be used if the number of classes exceeds 2. Predictions with probability > threshold are considered to be class 1, and are considered class 0 otherwise.
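The threshold rule can be sketched as follows. BinaryThreshold is a hypothetical class; for the two-output case it assumes the probability compared against the threshold is that of class 1 (the second output), and it rejects more than two outputs, in line with the restriction stated above. Note the strict comparison: a probability exactly equal to the threshold yields class 0.

```java
// Hypothetical sketch of a binary decision threshold.
final class BinaryThreshold {
    private BinaryThreshold() {}

    // One network output: p is the probability of class 1.
    static int classify(double p, double threshold) {
        return p > threshold ? 1 : 0; // strictly greater than the threshold -> class 1
    }

    // Two network outputs: probs = {P(class 0), P(class 1)}.
    static int classify(double[] probs, double threshold) {
        if (probs.length != 2) {
            throw new IllegalArgumentException(
                    "A decision threshold only applies to 1 or 2 outputs, got " + probs.length);
        }
        return classify(probs[1], threshold);
    }

    public static void main(String[] args) {
        System.out.println(classify(new double[] {0.7, 0.3}, 0.25)); // low threshold favours class 1
        System.out.println(classify(new double[] {0.7, 0.3}, 0.5));
    }
}
```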

Cost arrays (a row vector, of size equal to the number of outputs) modify the evaluation process: instead of simply doing predictedClass = argMax(probabilities), we do predictedClass = argMax(cost * probabilities). Consequently, an array of all 1s (or indeed any array of equal values) will result in the same performance as no cost array; non-equal values will bias the predictions for or against certain classes.
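The rule predictedClass = argMax(cost * probabilities) is an elementwise product followed by an argmax. In this sketch (hypothetical class CostArrayPrediction; ties go to the lowest index, which may not match DL4J), a uniform cost array leaves the prediction unchanged while a non-uniform one can shift it.

```java
// Hypothetical sketch of cost-weighted prediction.
final class CostArrayPrediction {
    private CostArrayPrediction() {}

    // Index of the largest value; ties go to the lowest index in this sketch.
    static int argMax(double[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) {
                best = i;
            }
        }
        return best;
    }

    // predictedClass = argMax(cost * probabilities), with an elementwise product.
    static int predict(double[] probabilities, double[] cost) {
        if (probabilities.length != cost.length) {
            throw new IllegalArgumentException("cost array must have one entry per output");
        }
        double[] weighted = new double[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            weighted[i] = cost[i] * probabilities[i];
        }
        return argMax(weighted);
    }

    public static void main(String[] args) {
        double[] probs = {0.5, 0.3, 0.2};
        System.out.println(predict(probs, new double[] {1, 1, 1})); // same as plain argmax
        System.out.println(predict(probs, new double[] {1, 2, 1})); // biased towards class 1
    }
}
```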

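        import java.util.List;
        import org.datavec.api.records.metadata.RecordMetaData;
        import org.nd4j.linalg.api.ndarray.INDArray;
        // Recent DL4J releases; older releases use org.deeplearning4j.eval.Evaluation
        // and org.deeplearning4j.eval.meta.Prediction instead.
        import org.nd4j.evaluation.classification.Evaluation;
        import org.nd4j.evaluation.meta.Prediction;

        // Assumes `model` (a trained network), `log` (an SLF4J Logger) and `testData`
        // (a DataSet from an iterator with setCollectMetaData(true)) are defined elsewhere.
        List<RecordMetaData> testMetaData = testData.getExampleMetaData(RecordMetaData.class);

        Evaluation eval = new Evaluation(3);                          // 3 classes
        INDArray output = model.output(testData.getFeatures());       // predicted class probabilities per example
        eval.eval(testData.getLabels(), output, testMetaData);        // note: the test set metadata is passed in here
        log.info("\n{}\n", output);
        log.info("Evaluation report:\n{}", eval.stats());             // classification statistics as a String

        List<Prediction> predictionErrors = eval.getPredictionErrors();
        System.out.println("\n\n+++++ Prediction Errors +++++");
        for (Prediction p : predictionErrors) {
            System.out.println("Predicted class: " + p.getPredictedClass() + ", Actual class: " + p.getActualClass()
                + "\t" + p.getRecordMetaData(RecordMetaData.class).getLocation());
        }

        List<Prediction> list1 = eval.getPredictions(1, 2);               // actual class 1, predicted class 2
        List<Prediction> list2 = eval.getPredictionByPredictedClass(2);   // all predictions with predicted class 2
        List<Prediction> list3 = eval.getPredictionsByActualClass(2);     // all predictions with actual class 2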

No. | Method | Description
1   | Evaluation(int numClasses) | Constructor; numClasses is the number of classes to account for in the evaluation
2   | eval(INDArray realOutcomes, INDArray guesses, java.util.List<? extends java.io.Serializable> recordMetaData) | Evaluate the network, with optional metadata
3   | getPredictionErrors() | Get a list of prediction errors, on a per-record basis
4   | getPredictions(int actualClass, int predictedClass) | Get a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair)
5   | getPredictionByPredictedClass(int predictedClass) | Get a list of predictions, for all data with the specified predicted class, regardless of the actual data class
6   | getPredictionsByActualClass(int actualClass) | Get a list of predictions, for all data with the specified actual class, regardless of the predicted class
7   | stats() | Report the classification statistics as a String

Note for rows 5 and 6: prediction errors are ONLY available if the "evaluate with metadata" method eval(INDArray, INDArray, List) is used. Otherwise (if the metadata hasn't been recorded via that eval method), there is no value in splitting each prediction out into a separate Prediction object; instead, use the confusion matrix to get the counts, via getConfusionMatrix().
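Why metadata matters for the prediction-list methods: once each example is stored as an individual prediction (actual class, predicted class and where the record came from), rows 3 to 6 amount to filters over that list. The classes RecordPrediction and PredictionStore below, their method names and the "line N" locations are hypothetical and only mimic the idea with the Java standard library; they are not DL4J's Prediction or Evaluation classes.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical per-record prediction (not DL4J's Prediction class).
final class RecordPrediction {
    final int actualClass;
    final int predictedClass;
    final String location; // e.g. what RecordMetaData.getLocation() would return

    RecordPrediction(int actualClass, int predictedClass, String location) {
        this.actualClass = actualClass;
        this.predictedClass = predictedClass;
        this.location = location;
    }
}

// Hypothetical store showing how per-record predictions can be sliced.
final class PredictionStore {
    private final List<RecordPrediction> predictions = new ArrayList<>();

    void add(int actualClass, int predictedClass, String location) {
        predictions.add(new RecordPrediction(actualClass, predictedClass, location));
    }

    private List<RecordPrediction> filter(Predicate<RecordPrediction> keep) {
        List<RecordPrediction> out = new ArrayList<>();
        for (RecordPrediction p : predictions) {
            if (keep.test(p)) {
                out.add(p);
            }
        }
        return out;
    }

    // Analogue of row 3: every record whose prediction was wrong.
    List<RecordPrediction> errors() { return filter(p -> p.actualClass != p.predictedClass); }

    // Analogue of row 4: one confusion-matrix cell.
    List<RecordPrediction> byCell(int actualClass, int predictedClass) {
        return filter(p -> p.actualClass == actualClass && p.predictedClass == predictedClass);
    }

    // Analogue of row 5: one confusion-matrix column.
    List<RecordPrediction> byPredictedClass(int predictedClass) {
        return filter(p -> p.predictedClass == predictedClass);
    }

    // Analogue of row 6: one confusion-matrix row.
    List<RecordPrediction> byActualClass(int actualClass) {
        return filter(p -> p.actualClass == actualClass);
    }

    public static void main(String[] args) {
        PredictionStore store = new PredictionStore();
        store.add(1, 2, "line 5");
        store.add(2, 2, "line 6");
        for (RecordPrediction p : store.errors()) {
            System.out.println("Predicted " + p.predictedClass + ", actual " + p.actualClass + "\t" + p.location);
        }
    }
}
```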

Example output of stats() (the confusion matrix at the end is also known as an error matrix):

========================Evaluation Metrics========================
 # of classes:    3
 Accuracy:        0.6038
 Precision:       0.7083
 Recall:          0.6845
 F1 Score:        0.5852
Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)


=========================Confusion Matrix=========================
  0  1  2
----------
 17  0  0 | 0 = 0
  0 12  1 | 1 = 1
  0 20  3 | 2 = 2

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
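As a sanity check, the headline numbers above can be recomputed from the confusion matrix with one-vs-all definitions: per-class precision divides the diagonal entry by its column sum, per-class recall by its row sum, and each macro-average is the unweighted mean over the 3 classes. This is a sketch (MulticlassMetrics is a hypothetical class, not DL4J code); rounded to four decimals it reproduces 0.6038, 0.7083, 0.6845 and 0.5852.

```java
import java.util.function.IntToDoubleFunction;

// One-vs-all metrics from a confusion matrix cm[actual][predicted].
final class MulticlassMetrics {
    private MulticlassMetrics() {}

    static double accuracy(int[][] cm) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < cm.length; i++) {
            for (int j = 0; j < cm[i].length; j++) {
                total += cm[i][j];
                if (i == j) {
                    correct += cm[i][j];
                }
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    // Precision of class c: diagonal entry / column sum (everything predicted as c).
    static double precision(int[][] cm, int c) {
        long predictedAsC = 0;
        for (int[] row : cm) {
            predictedAsC += row[c];
        }
        return predictedAsC == 0 ? 0.0 : (double) cm[c][c] / predictedAsC;
    }

    // Recall of class c: diagonal entry / row sum (everything actually c).
    static double recall(int[][] cm, int c) {
        long actualC = 0;
        for (int v : cm[c]) {
            actualC += v;
        }
        return actualC == 0 ? 0.0 : (double) cm[c][c] / actualC;
    }

    static double f1(int[][] cm, int c) {
        double p = precision(cm, c);
        double r = recall(cm, c);
        return p + r == 0.0 ? 0.0 : 2 * p * r / (p + r);
    }

    // Unweighted mean of a per-class metric over all classes.
    private static double macro(int[][] cm, IntToDoubleFunction perClass) {
        double sum = 0.0;
        for (int c = 0; c < cm.length; c++) {
            sum += perClass.applyAsDouble(c);
        }
        return sum / cm.length;
    }

    static double macroPrecision(int[][] cm) { return macro(cm, c -> precision(cm, c)); }

    static double macroRecall(int[][] cm) { return macro(cm, c -> recall(cm, c)); }

    static double macroF1(int[][] cm) { return macro(cm, c -> f1(cm, c)); }

    // Micro-average: pool TP and TP+FP over all classes before dividing.
    static double microPrecision(int[][] cm) {
        long tp = 0;
        long tpPlusFp = 0;
        for (int c = 0; c < cm.length; c++) {
            tp += cm[c][c];
            for (int[] row : cm) {
                tpPlusFp += row[c];
            }
        }
        return tpPlusFp == 0 ? 0.0 : (double) tp / tpPlusFp;
    }

    public static void main(String[] args) {
        int[][] cm = {{17, 0, 0}, {0, 12, 1}, {0, 20, 3}}; // the matrix from stats() above
        System.out.printf("Accuracy %.4f, Precision %.4f, Recall %.4f, F1 %.4f%n",
                accuracy(cm), macroPrecision(cm), macroRecall(cm), macroF1(cm));
    }
}
```

The F1 figure matches the mean of the per-class F1 scores (1.0, 0.5333, 0.2222), not the F1 of macro-precision and macro-recall, which would be about 0.696. For single-label data the micro-averaged precision equals the accuracy.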
