CMAR代码学习3 结合生成的classifier进行预测

CMAR代码学习1 生成FP-tree挖掘规则:https://blog.csdn.net/woshiwu6666/article/details/124717377
CMAR代码学习2 计算卡方值并利用CR-tree剪枝:https://blog.csdn.net/woshiwu6666/article/details/124790979

数据集一共有数据14条

训练集9条

测试集5条

在Evaluator当中

// For each fold, it does training and testing
		for (int i = 0; i < k; i++) {
//				//Partitioning database 
			int posStart = i * absoluteRatio; // start position of testing set
			int posEnd = posStart + absoluteRatio; // end position of testing set
			if (i == (k - 1)) { // if last fold we adjust the size to include all the left-over sequences
				posEnd = dataset.getInstances().size(); // special case
			}

			// Split the dataset in two parts
			Dataset[] datasets = VirtualDataset.splitDatasetForKFold(dataset, posStart, posEnd);
			Dataset training = datasets[0];
			Dataset testing = datasets[1];

			if (DEBUGMODE) {
				System.out.println("===== KFOLD " + i + " =====");
				System.out.println(" k = " + k);
				System.out.println("  - Original dataset: " + dataset.getInstances().size() + " records.");
				System.out.println("  - Training part: " + training.getInstances().size() + " records.");
				System.out.println("  - Testing part: " + testing.getInstances().size() + " records.");
				System.out.println("===== RUNNING =====");
			}

			// for each classifier
			for (ClassificationAlgorithm algorithm : algorithms) {
				if (DEBUGMODE) {
					System.out.println("Running algorithm ... " + algorithm.getName());
//						System.out.println(datasets[0].getMapClassToFrequency());
//						System.out.println(datasets[1].getMapClassToFrequency());
				}
				// Train the classifier
				Classifier classifier = algorithm.trainAndCalculateStats(training);
				TrainingResults trainResults = new TrainingResults();
				trainResults.memory += algorithm.getTrainingMaxMemory();
				trainResults.runtime += algorithm.getTrainingTime();
				if (classifier instanceof RuleClassifier) {
					trainResults.avgRuleCount += ((RuleClassifier) classifier).getNumberRules() / (double) k;
				}

				// Run on training set
				ClassificationResults resultsOnTraining = new ClassificationResults();
				runOnInstancesAnUpdateResults(training, classifier, resultsOnTraining);

				// Run on testing set
				ClassificationResults resultsOnTesting = new ClassificationResults();
				runOnInstancesAnUpdateResults(testing, classifier, resultsOnTesting);

				/** Save results for this classifier for this dataset */
				allResults.addResults(resultsOnTraining, resultsOnTesting, trainResults);
			}
		}
		return allResults;
	}

 其中,该部分用于预测训练集

// Run on training set
				ClassificationResults resultsOnTraining = new ClassificationResults();
				runOnInstancesAnUpdateResults(training, classifier, resultsOnTraining);

这部分用于预测测试集

// Run on testing set
				ClassificationResults resultsOnTesting = new ClassificationResults();
				runOnInstancesAnUpdateResults(testing, classifier, resultsOnTesting);

调用runOnInstancesAnUpdateResults函数

 short predictedKlassIndex = classifier.predict(instance);//预测值
 short realKlassIndex = instance.getKlass();//真实值

/**
 * Classifies every instance of a dataset and accumulates the outcome into {@code results}.
 * <p>
 * For each instance the predicted class is appended to {@code results.predictedClasses} and the
 * (real, predicted) pair is added to {@code results.matrix}. The elapsed time (in milliseconds)
 * and the maximum memory reported by {@link MemoryLogger} are added to the running totals.
 *
 * @param dataset    the instances to classify
 * @param classifier the trained classifier used for prediction
 * @param results    the accumulator updated in place
 */
private void runOnInstancesAnUpdateResults(Dataset dataset, Classifier classifier, ClassificationResults results) {
	MemoryLogger.getInstance().reset();
	// Use the monotonic clock for elapsed time: currentTimeMillis() is wall-clock time and can
	// jump (NTP adjustments, manual changes), which would corrupt measured durations.
	long startNanos = System.nanoTime();

	for (Instance instance : dataset.getInstances()) {
		short predictedKlassIndex = classifier.predict(instance);
		short realKlassIndex = instance.getKlass();

		results.predictedClasses.add(predictedKlassIndex);
		results.matrix.add(realKlassIndex, predictedKlassIndex);
	}
	// Stored in milliseconds, as before.
	results.runtime += (System.nanoTime() - startNanos) / 1_000_000L;
	MemoryLogger.getInstance().checkMemory();
	results.memory += MemoryLogger.getInstance().getMaxMemory();
}

在ClassifierCMAR.java当中,重写predict函数

分类结果分成三类

  1. 没有matching rules
  2. 只有一条matching rule,或有多条matching rules但它们都属于同一个class(onlyOneClass()返回true),直接返回该class
  3. 有多条matching rules,有多个class,需要调用groupRulesByKlass(matchingRules)对matchingrules分类
 /**
     * Predicts the class of an instance using the CMAR decision procedure:
     * <ol>
     * <li>no matching rule: {@code NOPREDICTION} is returned;</li>
     * <li>one matching rule, or several rules that all share a class: that class is returned;</li>
     * <li>otherwise the rules are grouped by class and the class chosen by
     * getClassWithBestChiQuareValue (best weighted chi-square, WCS) is returned.</li>
     * </ol>
     *
     * @param instance the record to classify
     * @return the predicted class, or {@code NOPREDICTION} when no rule matches
     */
    @Override
    public short predict(Instance instance) {
        List<RuleCMAR> candidates = obtainallRulesForRecord(instance.getItems());

        if (candidates.isEmpty()) {
            // Nothing matched the record, so no prediction can be made.
            return NOPREDICTION;
        }

        // A lone rule, or several rules that agree, decide the class without any scoring.
        // The size() check is a cheap shortcut evaluated before scanning the whole list.
        if (candidates.size() == 1 || onlyOneClass(candidates)) {
            return candidates.get(0).getKlass();
        }

        // Conflicting classes: score each class group by its WCS value and keep the best one.
        return getClassWithBestChiQuareValue(groupRulesByKlass(candidates));
    }

onlyOneClass(),判断matching rules是否只有一个class 

 /**
     * Tells whether all the given rules predict the same class.
     * <p>
     * The class of the first rule is the reference, so the list must not be empty.
     *
     * @param rules the rules to inspect (at least one rule)
     * @return true if every rule has the same class as the first one, false otherwise
     * @throws IndexOutOfBoundsException if {@code rules} is empty
     */
    private boolean onlyOneClass(List<RuleCMAR> rules) {
        short reference = rules.get(0).getKlass();
        // The first rule trivially matches itself, so only the remaining ones are checked.
        for (RuleCMAR other : rules.subList(1, rules.size())) {
            if (other.getKlass() != reference) {
                return false;
            }
        }
        return true;
    }

 groupRulesByKlass(matchingRules),将不同class的matching rules进行分类

  /**
     * Splits the given rules into groups that share the same consequent class.
     *
     * @param rules the rules to partition
     * @return a map from class to the rules predicting that class, each list keeping the
     *         relative order in which the rules appeared in {@code rules}
     */
    private Map<Short, List<RuleCMAR>> groupRulesByKlass(List<RuleCMAR> rules) {
        Map<Short, List<RuleCMAR>> groups = new HashMap<Short, List<RuleCMAR>>();
        for (RuleCMAR current : rules) {
            Short klass = current.getKlass();
            // One map lookup per rule; the bucket is created the first time a class is seen.
            List<RuleCMAR> bucket = groups.get(klass);
            if (bucket == null) {
                bucket = new ArrayList<RuleCMAR>();
                groups.put(klass, bucket);
            }
            bucket.add(current);
        }
        return groups;
    }

以下是预测一条用于测试的例子

instance2 4 8 9 11 

通过predict中的obtainallRulesForRecord(instance.getItems())得到的matching rules如下

0:[4] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
1:[2, 4] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
2:[2, 4, 8] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
3:[2, 4, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
4:[2, 4, 8, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
5:[4, 9] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
6:[4, 9] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
7:[4, 8, 9] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
8:[4, 8, 9] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
9:[4, 8] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
10:[2] -> 11 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 2.25
11:[2, 9] -> 11 #SUP: 2 #CONF: 1.0 #CHISQUARE: 5.142857142857143
12:[2, 8, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
13:[2, 8] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
14:[2, 8] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
15:[9] -> 12 #SUP: 3 #CONF: 0.6 #CHISQUARE: 0.22500000000000003
16:[8, 9] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
17:[8] -> 12 #SUP: 5 #CONF: 0.8333333333333334 #CHISQUARE: 2.25

调用 groupRulesByKlass(matchingRules)分类得到的结果,一个是分类为11,一个是分类为12

0:[2, 4] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
1:[2, 4, 8] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
2:[2, 4, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
3:[2, 4, 8, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
4:[4, 9] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
5:[4, 8, 9] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
6:[2] -> 11 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 2.25
7:[2, 9] -> 11 #SUP: 2 #CONF: 1.0 #CHISQUARE: 5.142857142857143
8:[2, 8, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
9:[2, 8] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145 

0:[4] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
1:[4, 9] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
2:[4, 8, 9] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
3:[4, 8] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
4:[2, 8] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
5:[9] -> 12 #SUP: 3 #CONF: 0.6 #CHISQUARE: 0.22500000000000003
6:[8, 9] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
7:[8] -> 12 #SUP: 5 #CONF: 0.8333333333333334 #CHISQUARE: 2.25 

随后,分别对这两个组计算加权卡方值(Weighted Chi-Square, WCS)

 

 然后,选择WCS值最高的组,以该组的class作为这个case的预测分类。由下面的输出可知class 11的WCS值更高,因此选择11作为它的分类

key11 wcsValue17.015625000000004
key12 wcsValue0.8125714285714285

 随后调用results.predictedClasses.add(predictedKlassIndex);results.matrix.add(realKlassIndex, predictedKlassIndex);将结果放进ClassificationResults当中保存记录,方便计算结果
 

class ClassificationResults {
	/** Counts of (real class, predicted class) pairs observed during the run. */
	ConfusionMatrix matrix = new ConfusionMatrix();
	/** Predicted class of each classified instance, in processing order. */
	List<Short> predictedClasses = new ArrayList<Short>();
	/** Accumulated classification time (the evaluator adds millisecond durations). */
	long runtime = 0L;
	/** Accumulated memory usage reported by the memory logger (reported as MB in the metrics). */
	Double memory = 0.0;
}
/**
 * Creates an empty confusion matrix. Rows (one per real class) are kept in a TreeMap, so they
 * are ordered by class value.
 */
public ConfusionMatrix() {
	matrix = new TreeMap<>();
}

matrix.add(realKlassIndex, predictedKlassIndex)中有几个变量值得注意

predictedClasses:记录每条case的预测结果

allRealklasss:累计记录数据中实际出现过的class

allPredictedklasss:累计记录分类器预测出的class

map与value:map是matrix中真实class(realValue)对应的那一行,value是该行中预测class(observedValue)已出现的次数,每调用一次加1

nopredictions:统计分类器无法做出预测(预测值为Classifier.NOPREDICTION)的实例数量;预测错误的数量可由total - correct得到

correct:统计预测正确的数量

	/**
	 * Records one (real class, predicted class) observation in the confusion matrix.
	 * <p>
	 * Besides incrementing the cell {@code matrix[realValue][observedValue]}, this also records
	 * both classes in the collections of seen real/predicted classes. It increments the total
	 * count, the number of correct predictions (real equals predicted), and the number of
	 * instances for which the classifier made no prediction.
	 *
	 * @param realValue     the true class of the instance
	 * @param observedValue the class returned by the classifier, possibly
	 *                      {@code Classifier.NOPREDICTION}
	 */
	public void add(Short realValue, Short observedValue) {
		allRealklasss.add(realValue);
		allPredictedklasss.add(observedValue);

		// Fetch (or lazily create) the row for the real class, then bump the cell once.
		// (Optimized by Philippe to avoid redundant map lookups.)
		Map<Short, Long> row = matrix.get(realValue);
		if (row == null) {
			row = new TreeMap<Short, Long>();
			matrix.put(realValue, row);
		}
		Long count = row.get(observedValue);
		row.put(observedValue, count == null ? 1L : count + 1);

		this.total += 1;
		// "No prediction" is signalled by the classifier's OUTPUT, not by the real class.
		// The original code tested realValue here, so this counter always stayed at 0.
		// Comparing via shortValue() is a numeric comparison, independent of NOPREDICTION's declared type.
		if (observedValue != null && observedValue.shortValue() == Classifier.NOPREDICTION) {
			nopredictions += 1;
		}

		if (realValue.equals(observedValue)) {
			this.correct += 1;
		}
	}

matrix以真实class为键,每个值是一个"预测class → 次数"的map,用于统计每种(真实class, 预测class)组合出现的次数

==========================================
matrix 
11{11=1}
==========================================
matrix 
11{11=1}12{12=1}
==========================================
matrix 
11{11=2}12{12=1}
==========================================
matrix 
11{11=2}12{12=2}
==========================================

OverallResults.java当中

调用addResults(ClassificationResults resultsOnTraining, 
            ClassificationResults resultsOnTesting, 
            TrainingResults trainResults)保存结果

public void addResults(ClassificationResults resultsOnTraining, 
			ClassificationResults resultsOnTesting, 
			TrainingResults trainResults) {
		if(trainResults != null) {
			runtimeToTrain.add(trainResults.runtime);
			memoryToTrain.add(trainResults.memory);
			avgRuleCount.add(trainResults.avgRuleCount);
		}
		
		if(resultsOnTraining != null) {
			listMatrixOnTraining.add(resultsOnTraining.matrix);
			runtimeOnTraining.add(resultsOnTraining.runtime);
			memoryUsageOnTraining.add(resultsOnTraining.memory);
			
		}
		if(resultsOnTesting != null) {
			listMatrixOnTesting.add(resultsOnTesting.matrix);
			predictedClasseOnTesting.add(resultsOnTesting.predictedClasses);
			runtimeOnTesting.add(resultsOnTesting.runtime);
			memoryUsageOnTesting.add(resultsOnTesting.memory);
		}
	}

计算结果,并将结果保存到文本当中

MainTestCMAR_batch_kfold.java

其中

allResults.saveMetricsResultsToFile(forTrainingPath, onTrainingPath, onTrestingPath);//保存结果到txt当中

allResults.printStats()//输出结果到console当中


		// We run the experiment
		OverallResults allResults = experiment1.trainAndRunClassifiersKFold(algorithms, dataset, kFoldCount);

		// Save statistics about the execution to files (optional)
		String forTrainingPath = "outputReportForTraining.txt";
		String onTrainingPath = "outputReportOnTraining.txt";
		String onTrestingPath = "outputReportOnTesting.txt";
		allResults.saveMetricsResultsToFile(forTrainingPath, onTrainingPath, onTrestingPath);

		// Print statistics to the console (optional)
		allResults.printStats();
/**
	 * Writes the collected metrics to up to three UTF-8 report files. A {@code null} path skips
	 * the corresponding report. I/O failures are printed and the remaining reports are skipped.
	 * 
	 * @param toTrainpath    file for training statistics (runtime and memory to train), or {@code null}
	 * @param onTrainingPath file for metrics measured on the training set, or {@code null}
	 * @param onTestingPath  file for metrics measured on the testing set, or {@code null}
	 */
	public void saveMetricsResultsToFile(String toTrainpath, String onTrainingPath, String onTestingPath) {
		try {
			if (toTrainpath != null) {
				writeMetricsReport(toTrainpath, trainingMetricsToString(runtimeToTrain, memoryToTrain));
			}
			if (onTrainingPath != null) {
				writeMetricsReport(onTrainingPath,
						metricsToString(listMatrixOnTraining, runtimeOnTraining, memoryUsageOnTraining));
			}
			if (onTestingPath != null) {
				writeMetricsReport(onTestingPath,
						metricsToString(listMatrixOnTesting, runtimeOnTesting, memoryUsageOnTesting));
			}
		} catch (FileNotFoundException | UnsupportedEncodingException e) {
			// Best effort, as before: report the failure without aborting the experiment.
			e.printStackTrace();
		}
	}

	/**
	 * Writes {@code content} to {@code path} as UTF-8. The writer is always closed, even if
	 * writing fails (the original code leaked it in that case).
	 *
	 * @param path    destination file (created or truncated)
	 * @param content text to write
	 * @throws FileNotFoundException        if the file cannot be created or opened
	 * @throws UnsupportedEncodingException if UTF-8 is not supported
	 */
	private static void writeMetricsReport(String path, String content)
			throws FileNotFoundException, UnsupportedEncodingException {
		try (PrintWriter metricsWriter = new PrintWriter(path, "UTF-8")) {
			metricsWriter.write(content);
		}
	}

调用 metricsToString(List<ConfusionMatrix> listMatrix,List<Long> runtimes,List<Double> memoryUsages)函数,计算结果

/**
	 * Formats one metrics report as tab-separated rows (name, accuracy, recall, precision, kappa,
	 * micro/macro F-measure, time, memory, no-prediction count), one column per algorithm.
	 * Every row except the last is followed by the platform line separator.
	 *
	 * @param listMatrix   confusion matrix of each algorithm
	 * @param runtimes     runtime of each algorithm
	 * @param memoryUsages memory usage of each algorithm
	 * @return the formatted report
	 */
	private String metricsToString(List<ConfusionMatrix> listMatrix,
			List<Long> runtimes,
			List<Double> memoryUsages) {
		// Each row gets its own buffer seeded with its label; one pass over the algorithms then
		// appends a tab-separated cell per algorithm to every row.
		StringBuilder nameRow = new StringBuilder("#NAME:\t");
		StringBuilder accuracyRow = new StringBuilder("#ACCURACY:");
		StringBuilder recallRow = new StringBuilder("#RECALL:");
		StringBuilder precisionRow = new StringBuilder("#PRECISION:");
		StringBuilder kappaRow = new StringBuilder("#KAPPA:");
		StringBuilder fMicroRow = new StringBuilder("#FMICRO:");
		StringBuilder fMacroRow = new StringBuilder("#FMACRO:");
		StringBuilder timeRow = new StringBuilder("#TIMEms:");
		StringBuilder memoryRow = new StringBuilder("#MEMORYmb:");
		StringBuilder noPredictionRow = new StringBuilder("#NOPREDICTION:");

		for (int algo = 0; algo < algorithmCount; algo++) {
			ConfusionMatrix cm = listMatrix.get(algo);
			nameRow.append('\t').append(names.get(algo));
			accuracyRow.append('\t').append(df.format(cm.getAccuracy()));
			recallRow.append('\t').append(df.format(cm.getAverageRecall()));
			precisionRow.append('\t').append(df.format(cm.getAveragePrecision()));
			kappaRow.append('\t').append(df.format(cm.getKappa()));
			fMicroRow.append('\t').append(df.format(cm.getMicroFMeasure()));
			fMacroRow.append('\t').append(df.format(cm.getMacroFMeasure()));
			timeRow.append('\t').append(runtimes.get(algo));
			memoryRow.append('\t').append(df.format(memoryUsages.get(algo)));
			noPredictionRow.append('\t').append(cm.getNopredictions());
		}

		String eol = System.lineSeparator();
		StringBuilder report = new StringBuilder();
		report.append(nameRow).append(eol)
				.append(accuracyRow).append(eol)
				.append(recallRow).append(eol)
				.append(precisionRow).append(eol)
				.append(kappaRow).append(eol)
				.append(fMicroRow).append(eol)
				.append(fMacroRow).append(eol)
				.append(timeRow).append(eol)
				.append(memoryRow).append(eol)
				// The last row is deliberately not terminated, matching the existing format.
				.append(noPredictionRow);
		return report.toString();
	}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值