CMAR代码学习1 生成FP-tree挖掘规则:https://blog.csdn.net/woshiwu6666/article/details/124717377
CMAR代码学习2 计算卡方值并利用CR-tree剪枝:https://blog.csdn.net/woshiwu6666/article/details/124790979
数据集一共有数据14条
训练集9条
测试集5条
在Evaluator当中
// For each fold, it does training and testing
for (int i = 0; i < k; i++) {
// //Partitioning database
int posStart = i * absoluteRatio; // start position of testing set
int posEnd = posStart + absoluteRatio; // end position of testing set
if (i == (k - 1)) { // if last fold we adjust the size to include all the left-over sequences
posEnd = dataset.getInstances().size(); // special case
}
// Split the dataset in two parts
Dataset[] datasets = VirtualDataset.splitDatasetForKFold(dataset, posStart, posEnd);
Dataset training = datasets[0];
Dataset testing = datasets[1];
if (DEBUGMODE) {
System.out.println("===== KFOLD " + i + " =====");
System.out.println(" k = " + k);
System.out.println(" - Original dataset: " + dataset.getInstances().size() + " records.");
System.out.println(" - Training part: " + training.getInstances().size() + " records.");
System.out.println(" - Testing part: " + testing.getInstances().size() + " records.");
System.out.println("===== RUNNING =====");
}
// for each classifier
for (ClassificationAlgorithm algorithm : algorithms) {
if (DEBUGMODE) {
System.out.println("Running algorithm ... " + algorithm.getName());
// System.out.println(datasets[0].getMapClassToFrequency());
// System.out.println(datasets[1].getMapClassToFrequency());
}
// Train the classifier
Classifier classifier = algorithm.trainAndCalculateStats(training);
TrainingResults trainResults = new TrainingResults();
trainResults.memory += algorithm.getTrainingMaxMemory();
trainResults.runtime += algorithm.getTrainingTime();
if (classifier instanceof RuleClassifier) {
trainResults.avgRuleCount += ((RuleClassifier) classifier).getNumberRules() / (double) k;
}
// Run on training set
ClassificationResults resultsOnTraining = new ClassificationResults();
runOnInstancesAnUpdateResults(training, classifier, resultsOnTraining);
// Run on testing set
ClassificationResults resultsOnTesting = new ClassificationResults();
runOnInstancesAnUpdateResults(testing, classifier, resultsOnTesting);
/** Save results for this classifier for this dataset */
allResults.addResults(resultsOnTraining, resultsOnTesting, trainResults);
}
}
return allResults;
}
其中,该部分用于预测训练集
// Run on training set
ClassificationResults resultsOnTraining = new ClassificationResults();
runOnInstancesAnUpdateResults(training, classifier, resultsOnTraining);
这部分用于预测测试集
// Run on testing set
ClassificationResults resultsOnTesting = new ClassificationResults();
runOnInstancesAnUpdateResults(testing, classifier, resultsOnTesting);
调用runOnInstancesAnUpdateResults函数
short predictedKlassIndex = classifier.predict(instance);//预测值
short realKlassIndex = instance.getKlass();//真实值
/**
 * Classifies every instance of the given dataset with the given classifier
 * and accumulates the outcome into {@code results}: the predicted class of
 * each instance, the (real, predicted) pair in the confusion matrix, plus
 * the elapsed time and peak memory of this pass.
 */
private void runOnInstancesAnUpdateResults(Dataset dataset, Classifier classifier, ClassificationResults results) {
	MemoryLogger.getInstance().reset();
	long startTime = System.currentTimeMillis();
	for (Instance instance : dataset.getInstances()) {
		// Predict once, then record the prediction and the confusion-matrix entry
		short predicted = classifier.predict(instance);
		results.predictedClasses.add(predicted);
		results.matrix.add(instance.getKlass(), predicted);
	}
	results.runtime += System.currentTimeMillis() - startTime;
	MemoryLogger.getInstance().checkMemory();
	results.memory += MemoryLogger.getInstance().getMaxMemory();
}
在ClassifierCMAR.java当中,重写predict函数
分类结果分成三类
- 没有matching rules
- 有多条matching rules,但只有onlyOneClass()
- 有多条matching rules,有多个class,需要调用groupRulesByKlass(matchingRules)对matchingrules分类
@Override
public short predict(Instance instance) {
	// Collect every rule whose antecedent is contained in this record
	List<RuleCMAR> matchingRules = obtainallRulesForRecord(instance.getItems());
	// Case 1: no rule matches -> no prediction can be made
	if (matchingRules.isEmpty()) {
		return NOPREDICTION;
	}
	// Case 2: a single matching rule, or several rules that all agree
	// on the same class -> return that class directly
	boolean unanimous = matchingRules.size() == 1 || onlyOneClass(matchingRules);
	if (unanimous) {
		return matchingRules.get(0).getKlass();
	}
	// Case 3: conflicting classes -> group the rules by class, compute a
	// weighted chi-squared (WCS) value per group, and return the class of
	// the group with the best WCS value
	Map<Short, List<RuleCMAR>> ruleGroups = groupRulesByKlass(matchingRules);
	return getClassWithBestChiQuareValue(ruleGroups);
}
onlyOneClass(),判断matching rules是否只有一个class
/**
 * Checks whether all the given rules predict the same class.
 *
 * @param rules a non-empty list of rules to inspect
 * @return true if every rule has the same consequent class, false otherwise
 */
private boolean onlyOneClass(List<RuleCMAR> rules) {
	// Compare every rule's class against the first one; the first
	// comparison is trivially equal, so starting the loop at it is harmless
	short reference = rules.get(0).getKlass();
	for (RuleCMAR rule : rules) {
		if (rule.getKlass() != reference) {
			return false;
		}
	}
	return true;
}
groupRulesByKlass(matchingRules),将不同class的matching rules进行分类
/**
 * Forms groups of rules in function of their consequent (class).
 *
 * @param rules the rules to be grouped by consequent
 * @return a map from class to the list of rules predicting that class
 */
private Map<Short, List<RuleCMAR>> groupRulesByKlass(List<RuleCMAR> rules) {
	Map<Short, List<RuleCMAR>> rulesByGroup = new HashMap<Short, List<RuleCMAR>>();
	for (RuleCMAR rule : rules) {
		// computeIfAbsent does a single map lookup instead of get + null-check + put
		rulesByGroup.computeIfAbsent(rule.getKlass(), k -> new ArrayList<RuleCMAR>()).add(rule);
	}
	return rulesByGroup;
}
以下是预测一条用于测试的例子
instance2 4 8 9 11
通过 obtainallRulesForRecord(instance.getItems()) 得到的matching rules
0:[4] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
1:[2, 4] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
2:[2, 4, 8] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
3:[2, 4, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
4:[2, 4, 8, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
5:[4, 9] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
6:[4, 9] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
7:[4, 8, 9] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
8:[4, 8, 9] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
9:[4, 8] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
10:[2] -> 11 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 2.25
11:[2, 9] -> 11 #SUP: 2 #CONF: 1.0 #CHISQUARE: 5.142857142857143
12:[2, 8, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
13:[2, 8] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
14:[2, 8] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
15:[9] -> 12 #SUP: 3 #CONF: 0.6 #CHISQUARE: 0.22500000000000003
16:[8, 9] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
17:[8] -> 12 #SUP: 5 #CONF: 0.8333333333333334 #CHISQUARE: 2.25
调用 groupRulesByKlass(matchingRules)分类得到的结果,一个是分类为11,一个是分类为12
0:[2, 4] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
1:[2, 4, 8] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
2:[2, 4, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
3:[2, 4, 8, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
4:[4, 9] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
5:[4, 8, 9] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
6:[2] -> 11 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 2.25
7:[2, 9] -> 11 #SUP: 2 #CONF: 1.0 #CHISQUARE: 5.142857142857143
8:[2, 8, 9] -> 11 #SUP: 1 #CONF: 1.0 #CHISQUARE: 2.2500000000000004
9:[2, 8] -> 11 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
0:[4] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
1:[4, 9] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
2:[4, 8, 9] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
3:[4, 8] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
4:[2, 8] -> 12 #SUP: 1 #CONF: 0.5 #CHISQUARE: 0.32142857142857145
5:[9] -> 12 #SUP: 3 #CONF: 0.6 #CHISQUARE: 0.22500000000000003
6:[8, 9] -> 12 #SUP: 2 #CONF: 0.6666666666666666 #CHISQUARE: 0.0
7:[8] -> 12 #SUP: 5 #CONF: 0.8333333333333334 #CHISQUARE: 2.25
随后,分别对这两个组,计算累计卡方值
然后,卡方值高的作为这个case的分类,因此选择11作为它的分类
key11 wcsValue17.015625000000004
key12 wcsValue0.8125714285714285
随后调用results.predictedClasses.add(predictedKlassIndex);results.matrix.add(realKlassIndex, predictedKlassIndex);将结果放进ClassificationResults当中保存记录,方便计算结果
// Holds the outcome of classifying one dataset with one classifier.
class ClassificationResults {
// Confusion matrix accumulated over all classified instances (real class vs predicted class)
ConfusionMatrix matrix = new ConfusionMatrix();
// Predicted class of each instance, in classification order
List<Short> predictedClasses = new ArrayList<Short>();
// Total classification time in milliseconds
long runtime = 0l;
// Peak memory usage of the classification pass -- NOTE(review): units look like mb
// based on the "#MEMORYmb" report label; confirm against MemoryLogger
Double memory = 0d;
}
/** Creates an empty confusion matrix (outer key: real class, inner key: predicted class, value: count). */
public ConfusionMatrix() {
this.matrix = new TreeMap<Short, Map<Short, Long>>();
}
matrix.add(realKlassIndex, predictedKlassIndex),其中该函数有几个变量值得注意
predictedClasses:记录每条case的预测结果
allRealklasss:累计记录数据中实际出现有的class
allPredictedklasss:累计记录数据中实际预测的class
map与value:按(真实类, 预测类)组合累计统计预测次数
nopredictions:统计无法给出预测(即预测值为NOPREDICTION)的数量
correct:统计预测正确的数量
/**
 * Adds both prediction and real value to the confusion matrix and
 * updates the derived counters (total, correct, nopredictions).
 *
 * @param realValue     the true class of the instance
 * @param observedValue the class predicted by the classifier
 *                      (Classifier.NOPREDICTION when no rule matched)
 */
public void add(Short realValue, Short observedValue) {
	// ===============================================
	// PHILIPPE: I have optimized the code below... The maps was accessed numerous
	// times while it was unnecessary.
	// ======================================
	allRealklasss.add(realValue);
	allPredictedklasss.add(observedValue);
	// Row of the matrix for this real class (created lazily)
	Map<Short, Long> map = matrix.get(realValue);
	if (map == null) {
		map = new TreeMap<Short, Long>();
		matrix.put(realValue, map);
	}
	Long value = map.get(observedValue);
	map.put(observedValue, value == null ? 1l : value + 1);
	this.total += 1;
	// BUG FIX: a failed prediction is signalled through the PREDICTED value
	// (predict() returns NOPREDICTION); the real class is never NOPREDICTION,
	// so the previous check on realValue could never increment this counter.
	if (observedValue.equals(Classifier.NOPREDICTION)) {
		nopredictions += 1;
	}
	if (realValue.equals(observedValue)) {
		this.correct += 1;
	}
}
matrix用于统计预测值的分类情况
==========================================
matrix
11{11=1}
==========================================
matrix
11{11=1}12{12=1}
==========================================
matrix
11{11=2}12{12=1}
==========================================
matrix
11{11=2}12{12=2}
==========================================
在OverallResults.java当中
调用addResults(ClassificationResults resultsOnTraining,
ClassificationResults resultsOnTesting,
TrainingResults trainResults)保存结果
/**
 * Stores the outcome of one run for one algorithm: training-phase
 * statistics, results measured on the training part and results measured
 * on the testing part. A null argument is simply skipped.
 */
public void addResults(ClassificationResults resultsOnTraining,
		ClassificationResults resultsOnTesting,
		TrainingResults trainResults) {
	// Training-phase statistics (runtime / memory / average rule count)
	if (trainResults != null) {
		runtimeToTrain.add(trainResults.runtime);
		memoryToTrain.add(trainResults.memory);
		avgRuleCount.add(trainResults.avgRuleCount);
	}
	// Metrics of the classifier re-applied to its own training data
	if (resultsOnTraining != null) {
		runtimeOnTraining.add(resultsOnTraining.runtime);
		memoryUsageOnTraining.add(resultsOnTraining.memory);
		listMatrixOnTraining.add(resultsOnTraining.matrix);
	}
	// Metrics on the held-out testing data (predictions kept as well)
	if (resultsOnTesting != null) {
		runtimeOnTesting.add(resultsOnTesting.runtime);
		memoryUsageOnTesting.add(resultsOnTesting.memory);
		listMatrixOnTesting.add(resultsOnTesting.matrix);
		predictedClasseOnTesting.add(resultsOnTesting.predictedClasses);
	}
}
计算结果,并将结果保存到文本当中
MainTestCMAR_batch_kfold.java
其中
allResults.saveMetricsResultsToFile(forTrainingPath, onTrainingPath, onTrestingPath);//保存结果到txt当中
allResults.printStats()//输出结果到console当中
// We run the experiment
OverallResults allResults = experiment1.trainAndRunClassifiersKFold(algorithms, dataset, kFoldCount);
// Save statistics about the execution to files (optional)
String forTrainingPath = "outputReportForTraining.txt";
String onTrainingPath = "outputReportOnTraining.txt";
String onTrestingPath = "outputReportOnTesting.txt";
allResults.saveMetricsResultsToFile(forTrainingPath, onTrainingPath, onTrestingPath);
// Print statistics to the console (optional)
allResults.printStats();
/**
 * Saves the computed metrics to up to three report files.
 * A null path disables the corresponding report.
 *
 * @param toTrainpath    file path for the training-phase statistics (runtime/memory)
 * @param onTrainingPath file path for the metrics measured on the training sets
 * @param onTestingPath  file path for the metrics measured on the testing sets
 */
public void saveMetricsResultsToFile(String toTrainpath, String onTrainingPath, String onTestingPath) {
	try {
		// try-with-resources guarantees each writer is closed even if write() throws
		if (toTrainpath != null) {
			try (PrintWriter metricsWriter = new PrintWriter(toTrainpath, "UTF-8")) {
				metricsWriter.write(trainingMetricsToString(runtimeToTrain, memoryToTrain));
			}
		}
		if (onTrainingPath != null) {
			try (PrintWriter metricsWriter = new PrintWriter(onTrainingPath, "UTF-8")) {
				metricsWriter.write(metricsToString(listMatrixOnTraining, runtimeOnTraining, memoryUsageOnTraining));
			}
		}
		if (onTestingPath != null) {
			try (PrintWriter metricsWriter = new PrintWriter(onTestingPath, "UTF-8")) {
				metricsWriter.write(metricsToString(listMatrixOnTesting, runtimeOnTesting, memoryUsageOnTesting));
			}
		}
	} catch (FileNotFoundException | UnsupportedEncodingException e) {
		e.printStackTrace();
	}
}
调用 metricsToString(List<ConfusionMatrix> listMatrix,List<Long> runtimes,List<Double> memoryUsages)函数,计算结果
/**
 * Builds a tab-separated textual report with one row per metric and one
 * column per algorithm, from the given confusion matrices, runtimes and
 * memory usages (all three lists are indexed by algorithm).
 *
 * @param listMatrix   one confusion matrix per algorithm
 * @param runtimes     total runtime (ms) per algorithm
 * @param memoryUsages memory usage (mb) per algorithm
 * @return the formatted report
 */
private String metricsToString(List<ConfusionMatrix> listMatrix,
		List<Long> runtimes,
		List<Double> memoryUsages) {
	StringBuilder builder = new StringBuilder();
	// Header row. NOTE: the label already ends with a tab, so each name is
	// preceded by two tabs -- kept as-is for output-format compatibility.
	builder.append("#NAME:\t");
	for (int i = 0; i < algorithmCount; i++) {
		builder.append("\t" + names.get(i));
	}
	// One row per metric. Each row starts with a line separator, which
	// reproduces the original format exactly (no trailing separator after
	// the last row).
	appendRow(builder, "#ACCURACY:", i -> df.format(listMatrix.get(i).getAccuracy()));
	appendRow(builder, "#RECALL:", i -> df.format(listMatrix.get(i).getAverageRecall()));
	appendRow(builder, "#PRECISION:", i -> df.format(listMatrix.get(i).getAveragePrecision()));
	appendRow(builder, "#KAPPA:", i -> df.format(listMatrix.get(i).getKappa()));
	appendRow(builder, "#FMICRO:", i -> df.format(listMatrix.get(i).getMicroFMeasure()));
	appendRow(builder, "#FMACRO:", i -> df.format(listMatrix.get(i).getMacroFMeasure()));
	appendRow(builder, "#TIMEms:", i -> String.valueOf(runtimes.get(i)));
	appendRow(builder, "#MEMORYmb:", i -> df.format(memoryUsages.get(i)));
	appendRow(builder, "#NOPREDICTION:", i -> String.valueOf(listMatrix.get(i).getNopredictions()));
	return builder.toString();
}

/**
 * Appends a line separator followed by one labelled report row, where the
 * cell for column i is produced by the given function.
 */
private void appendRow(StringBuilder builder, String label, java.util.function.IntFunction<String> cell) {
	builder.append(System.lineSeparator());
	builder.append(label);
	for (int i = 0; i < algorithmCount; i++) {
		builder.append("\t" + cell.apply(i));
	}
}