package machinelearning.decisiontree;
import java.io.FileReader;
import java.util.Arrays;
import weka.core.*;
/**
 * A node of an ID3 decision tree built on a Weka {@link Instances} dataset.
 *
 * <p>Each node keeps a reference to the shared dataset together with the
 * indices of the instances and attributes that are still available at this
 * node. The part of the class shown here covers construction, the purity
 * check, majority voting, attribute selection by minimal conditional entropy
 * and splitting the available instances by an attribute.
 *
 * @author Michelle Min MitchelleMin@163.com
 * @date 2021-07-21
 */
public class ID3 {
    /**
     * The data. Shared by reference among nodes created with
     * {@link #ID3(Instances, int[], int[])}.
     */
    Instances dataset;

    /**
     * Is this node's block pure (only one label)? Computed once in the
     * constructors from {@link #availableInstances}.
     */
    boolean pure;

    /**
     * The number of classes. For binary classification it is 2.
     */
    int numClasses;

    /**
     * Available instances (indices into {@link #dataset}). Other instances do
     * not belong to this branch.
     */
    int[] availableInstances;

    /**
     * Available attributes. Other attributes have been selected in the path
     * from the root.
     */
    int[] availableAttributes;

    /**
     * The selected attribute. Written by {@link #selectBestAttribute()}.
     */
    int splitAttribute;

    /**
     * The children nodes. {@code null} until children are created.
     */
    ID3[] children;

    /**
     * My label. Inner nodes also have a label. For example, &lt;outlook = sunny,
     * humidity = high&gt; never appear in the training data, but
     * &lt;humidity = high&gt; is valid in other cases.
     */
    int label;

    /**
     * The prediction, including queried and predicted labels.
     */
    int[] predicts;

    /**
     * Small block cannot be split further.
     */
    static int smallBlockThreshold = 3;

    /**
     * Builds the root node by reading a dataset from a file.
     *
     * <p>The last attribute is used as the class attribute, and every instance
     * and every non-class attribute is marked as available. If the file cannot
     * be read, a message is printed and the JVM is terminated with a non-zero
     * exit status.
     *
     * @param paraFilename
     *            The given file.
     */
    public ID3(String paraFilename) {
        dataset = null;
        // try-with-resources closes the reader even when parsing fails.
        try (FileReader fileReader = new FileReader(paraFilename)) {
            dataset = new Instances(fileReader);
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
            // A failed read is an error, so do not report success to the caller's shell.
            System.exit(1);
        } // of try

        dataset.setClassIndex(dataset.numAttributes() - 1);
        numClasses = dataset.classAttribute().numValues();

        availableInstances = new int[dataset.numInstances()];
        for (int i = 0; i < availableInstances.length; i++) {
            availableInstances[i] = i;
        } // of for i

        // The class attribute (the last one) is excluded.
        availableAttributes = new int[dataset.numAttributes() - 1];
        for (int i = 0; i < availableAttributes.length; i++) {
            availableAttributes[i] = i;
        } // of for i

        // Initialize.
        children = null;
        // Determine the label by simple voting.
        label = getMajorityClass(availableInstances);
        // Determine whether or not it is pure.
        pure = pureJudge(availableInstances);
    }// of the first constructor

    /**
     * Builds a node on a subset of an already loaded dataset.
     *
     * <p>The arrays are stored by reference, not cloned.
     *
     * @param paraDataset
     *            The given dataset. Its class index must already be set,
     *            because the class attribute is queried here.
     * @param paraAvailableInstances
     *            Indices of the instances that belong to this node.
     * @param paraAvailableAttributes
     *            Indices of the attributes that may still be used for splitting.
     */
    public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
        // Copy its reference instead of clone the availableInstances.
        dataset = paraDataset;
        availableInstances = paraAvailableInstances;
        availableAttributes = paraAvailableAttributes;
        // Keep numClasses consistent with the file-based constructor.
        numClasses = dataset.classAttribute().numValues();

        // Initialize.
        children = null;
        // Determine the label by simple voting.
        label = getMajorityClass(availableInstances);
        // Determine whether or not it is pure.
        pure = pureJudge(availableInstances);
    }// of the second constructor

    /**
     * Is the given block pure, i.e., do all its instances share one class value?
     *
     * <p>This method does not modify the node; the result is only returned.
     *
     * @param paraBlock
     *            The block (instance indices into the dataset).
     * @return True if pure. An empty or single-instance block is pure.
     */
    public boolean pureJudge(int[] paraBlock) {
        if (paraBlock.length == 0) {
            return true;
        } // of if

        double tempFirstClass = dataset.instance(paraBlock[0]).classValue();
        for (int i = 1; i < paraBlock.length; i++) {
            if (dataset.instance(paraBlock[i]).classValue() != tempFirstClass) {
                return false;
            } // of if
        } // of for i
        return true;
    }// of pureJudge

    /**
     * Compute the majority class of the given block for voting.
     *
     * @param paraBlock
     *            The block (instance indices into the dataset).
     * @return The majority class. On ties the smallest class index wins.
     */
    public int getMajorityClass(int[] paraBlock) {
        int[] tempClassCounts = new int[dataset.numClasses()];
        for (int i = 0; i < paraBlock.length; i++) {
            tempClassCounts[(int) dataset.instance(paraBlock[i]).classValue()]++;
        } // of for i

        int resultMajorityClass = -1;
        int tempMaxCount = -1;
        for (int i = 0; i < tempClassCounts.length; i++) {
            // Strict comparison keeps the first class that reaches the maximum.
            if (tempMaxCount < tempClassCounts[i]) {
                resultMajorityClass = i;
                tempMaxCount = tempClassCounts[i];
            } // of if
        } // of for i

        return resultMajorityClass;
    }// of getMajorityClass

    /**
     * Select the best attribute, i.e., the available attribute with the
     * minimal conditional entropy. The result is also stored in
     * {@link #splitAttribute}.
     *
     * @return The best attribute index, or -1 if no attribute is available.
     */
    public int selectBestAttribute() {
        splitAttribute = -1;
        double tempMinimalEntropy = Double.MAX_VALUE;
        double tempEntropy;
        for (int i = 0; i < availableAttributes.length; i++) {
            tempEntropy = conditionalEntropy(availableAttributes[i]);
            // Strict comparison keeps the first attribute that reaches the minimum.
            if (tempMinimalEntropy > tempEntropy) {
                tempMinimalEntropy = tempEntropy;
                splitAttribute = availableAttributes[i];
            } // of if
        } // of for i
        return splitAttribute;
    }// of selectBestAttribute

    /**
     * Compute the conditional entropy of the class given an attribute over the
     * available instances: the sum over attribute values of
     * (fraction of instances with that value) * (class entropy within them).
     * The natural logarithm is used.
     *
     * @param paraAttribute
     *            The given attribute. NOTE(review): the value is cast to an
     *            array index, so this presumably expects a nominal attribute.
     *
     * @return The entropy.
     */
    public double conditionalEntropy(int paraAttribute) {
        // Step 1. Statistics: count instances per value and per (value, class).
        int tempNumClasses = dataset.numClasses();
        int tempNumValues = dataset.attribute(paraAttribute).numValues();
        int tempNumInstances = availableInstances.length;
        double[] tempValueCounts = new double[tempNumValues];
        double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];

        int tempClass, tempValue;
        for (int i = 0; i < tempNumInstances; i++) {
            tempClass = (int) dataset.instance(availableInstances[i]).classValue();
            tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
            tempValueCounts[tempValue]++;
            tempCountMatrix[tempValue][tempClass]++;
        } // of for i

        // Step 2. Weighted sum of the per-value class entropies.
        double resultEntropy = 0;
        double tempEntropy, tempFraction;
        for (int i = 0; i < tempNumValues; i++) {
            // Values that do not occur contribute nothing (and would divide by zero).
            if (tempValueCounts[i] == 0) {
                continue;
            } // of if

            tempEntropy = 0;
            for (int j = 0; j < tempNumClasses; j++) {
                tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
                // Skip zero fractions: log(0) is undefined, the term is taken as 0.
                if (tempFraction == 0) {
                    continue;
                } // of if
                tempEntropy += -tempFraction * Math.log(tempFraction);
            } // of for j
            resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
        } // of for i

        return resultEntropy;
    }// of conditionalEntropy

    /**
     * Split the available instances according to the given attribute.
     *
     * @param paraAttribute
     *            The attribute to split on.
     * @return The blocks: one array of instance indices per attribute value,
     *         possibly of length 0.
     */
    public int[][] splitData(int paraAttribute) {
        int tempNumValues = dataset.attribute(paraAttribute).numValues();
        int[][] resultBlocks = new int[tempNumValues][];
        int[] tempSizes = new int[tempNumValues];

        // First scan to count the size of each block.
        int tempValue;
        for (int i = 0; i < availableInstances.length; i++) {
            tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
            tempSizes[tempValue]++;
        } // of for i

        // Allocate space.
        for (int i = 0; i < tempNumValues; i++) {
            resultBlocks[i] = new int[tempSizes[i]];
        } // of for i

        // Second scan to fill. tempSizes is reused as the per-block write cursor.
        Arrays.fill(tempSizes, 0);
        for (int i = 0; i < availableInstances.length; i++) {
            tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
            // Copy data.
            resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
            tempSizes[tempValue]++;
        } // of for i

        return resultBlocks;
    }// of splitData

    /**
     * Test this class.
     *
     * @param args
     *            Optional. If present, {@code args[0]} is the dataset file to
     *            load; otherwise the default weather dataset path is used.
     */
    public static void main(String[] args) {
        String tempFilename = "D:/mitchelles/data/weather.arff";
        if (args != null && args.length > 0) {
            tempFilename = args[0];
        } // of if

        ID3 tempID3 = new ID3(tempFilename);
        ID3.smallBlockThreshold = 3;
    }// Of main
}// of class ID3
// day61
// Latest recommended article, published 2024-06-05 08:50:53