算法介绍
决策树算法是一种常见的机器学习算法,用于分类和回归问题。它基于一系列的规则和条件,构建一棵树状结构来进行预测或决策。
训练过程
- 特征选择:从训练数据中选择一个最佳的特征作为根节点,将数据集分为不同的子集。
- 分支节点:对每个子集,重复进行特征选择的过程,选择一个最佳的特征作为当前子集的分支节点,继续划分数据集。
- 叶节点:重复上述过程,直到满足停止条件,例如达到最大深度、样本数量小于阈值等。此时,生成叶节点,用于分类或回归的预测。
特征选择方式
- 信息增益(Information Gain):信息增益是基于信息论的概念,用于衡量特征对数据集划分后整体熵减少的程度。计算信息增益时,首先计算划分前的熵,然后计算每个特征划分后的加权平均熵,最后用划分前熵减去加权平均熵得到信息增益。信息增益越大,表示特征对于分类的贡献越大。
- 基尼不纯度(Gini Impurity):基尼不纯度用于衡量数据集中类别的混杂程度,数据集 $D$ 的基尼不纯度为 $\mathrm{Gini}(D)=1-\sum_k p_k^2$,其中 $p_k$ 为第 $k$ 类样本所占比例。计算时,首先计算划分前数据集的基尼不纯度,然后计算按该特征划分后各子集基尼不纯度的加权平均值(即该特征的基尼指数),两者之差即为基尼增益。划分后的加权基尼指数越小(即基尼增益越大),表示特征对于分类的贡献越大。
- 增益率(Gain Ratio):增益率是在信息增益的基础上加入了对特征取值数目多少的惩罚,以防止信息增益偏向取值较多的特征。计算增益率时,先计算信息增益,再计算该特征自身取值分布的熵(称为分裂信息,Split Information),最后用信息增益除以分裂信息得到增益率。
- 卡方检验(Chi-square test):卡方检验用于衡量特征与类别之间的相关性。它基于统计学中的卡方检验方法,先计算特征与类别之间的卡方统计量及对应的 p 值,再与预先设定的显著性水平比较,以判断特征是否与类别相关。
算法优点
- 直观易解释,可以生成可视化的决策过程。
- 能够处理离散特征和连续特征。
- 可以处理多类别问题。
- 在一定条件下,对缺失值和异常值具有较好的容忍性。
算法缺点
- 对噪声敏感,容易过拟合,产生复杂的树结构。
- 对于包含大量特征和类别的数据集,决策树可能过于复杂,导致计算和存储开销增加。
- 决策树是一种贪心算法,可能无法找到全局最优解。
代码设计
构造方法
/**
 ********************
 * The constructor. Reads the dataset from the given file, uses the last
 * attribute as the class attribute, and initializes this node as the root
 * of the tree: every instance and every non-class attribute is available.
 *
 * @param paraFilename
 *            The given file.
 ********************
 */
public ID3(String paraFilename) {
    dataset = null;
    // try-with-resources closes the reader even when parsing throws.
    try (FileReader fileReader = new FileReader(paraFilename)) {
        dataset = new Instances(fileReader);
    } catch (Exception ee) {
        System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
        // Non-zero exit status so the calling process can see the failure.
        System.exit(1);
    } // Of try

    // The last attribute is treated as the class.
    dataset.setClassIndex(dataset.numAttributes() - 1);
    numClasses = dataset.classAttribute().numValues();

    // The root node covers every instance.
    availableInstances = new int[dataset.numInstances()];
    for (int i = 0; i < availableInstances.length; i++) {
        availableInstances[i] = i;
    } // Of for i

    // Every attribute except the class (the last one) may be split on.
    availableAttributes = new int[dataset.numAttributes() - 1];
    for (int i = 0; i < availableAttributes.length; i++) {
        availableAttributes[i] = i;
    } // Of for i

    // Initialize.
    children = null;
    // Determine the label by simple voting.
    label = getMajorityClass(availableInstances);
    // Determine whether or not it is pure.
    pure = pureJudge(availableInstances);
}// Of the first constructor
先读取数据文件,再初始化各成员变量(类别数、可用实例、可用属性、多数类标签和纯度标记)
投票选择最多标签
/**
 **********************************
 * Compute the majority class of the given block by simple voting. When two
 * classes receive the same number of votes, the lower class index wins.
 *
 * @param paraBlock
 *            Indices (into the dataset) of the instances that vote.
 * @return The majority class.
 **********************************
 */
public int getMajorityClass(int[] paraBlock) {
    int[] votes = new int[dataset.numClasses()];
    for (int instanceIndex : paraBlock) {
        int votedClass = (int) dataset.instance(instanceIndex).classValue();
        votes[votedClass]++;
    } // Of for instanceIndex

    int winner = -1;
    int winnerVotes = -1;
    int candidate = 0;
    while (candidate < votes.length) {
        // Strictly greater: an earlier class keeps the lead on a tie.
        if (votes[candidate] > winnerVotes) {
            winnerVotes = votes[candidate];
            winner = candidate;
        } // Of if
        candidate++;
    } // Of while
    return winner;
}// Of getMajorityClass
这个方法通过简单投票,返回数据块中出现次数最多的类别标签
判纯
/**
 **********************************
 * Is the given block pure, i.e., do all of its instances share one class
 * value? This is a pure query: it does not modify the node's {@code pure}
 * field, so it can safely be called on any block.
 *
 * @param paraBlock
 *            Indices (into the dataset) of the instances to check.
 * @return True if pure. An empty or single-instance block is pure.
 **********************************
 */
public boolean pureJudge(int[] paraBlock) {
    if (paraBlock.length == 0) {
        return true;
    } // Of if

    // Compare every instance against the first one.
    double tempFirstClass = dataset.instance(paraBlock[0]).classValue();
    for (int i = 1; i < paraBlock.length; i++) {
        if (dataset.instance(paraBlock[i]).classValue() != tempFirstClass) {
            return false;
        } // Of if
    } // Of for i
    return true;
}// Of pureJudge
判断该数据块中所有实例的类别标签是否相同(即是否为纯的)
选择条件熵最低(即信息增益最大)的属性
/**
 **********************************
 * Select the best attribute, i.e., the available attribute with the minimal
 * conditional entropy (equivalently, the maximal information gain). The
 * result is also stored in {@code splitAttribute}. On ties the attribute
 * appearing first in {@code availableAttributes} wins.
 *
 * @return The best attribute index, or -1 if no attribute is available.
 **********************************
 */
public int selectBestAttribute() {
    splitAttribute = -1;
    // Start from +infinity so any finite entropy is accepted, instead of
    // relying on an arbitrary upper bound.
    double tempMinimalEntropy = Double.POSITIVE_INFINITY;
    for (int i = 0; i < availableAttributes.length; i++) {
        double tempEntropy = conditionalEntropy(availableAttributes[i]);
        if (tempEntropy < tempMinimalEntropy) {
            tempMinimalEntropy = tempEntropy;
            splitAttribute = availableAttributes[i];
        } // Of if
    } // Of for i
    return splitAttribute;
}// Of selectBestAttribute
计算单个属性的条件熵
/**
 **********************************
 * Compute the conditional entropy H(class | attribute) over the available
 * instances, using the natural logarithm. Each attribute value contributes
 * the class entropy of its subset, weighted by the subset's share of the
 * available instances. Values that no available instance takes are skipped.
 *
 * @param paraAttribute
 *            The given attribute.
 *
 * @return The entropy (0 when there are no available instances).
 **********************************
 */
public double conditionalEntropy(int paraAttribute) {
    final int numClassValues = dataset.numClasses();
    final int numAttributeValues = dataset.attribute(paraAttribute).numValues();
    final int totalInstances = availableInstances.length;

    // Step 1. Count instances per attribute value and per (value, class) pair.
    double[] subsetSizes = new double[numAttributeValues];
    double[][] jointCounts = new double[numAttributeValues][numClassValues];
    for (int instanceIndex : availableInstances) {
        int valueIndex = (int) dataset.instance(instanceIndex).value(paraAttribute);
        int classIndex = (int) dataset.instance(instanceIndex).classValue();
        subsetSizes[valueIndex]++;
        jointCounts[valueIndex][classIndex]++;
    } // Of for instanceIndex

    // Step 2. Weighted sum of the per-subset class entropies.
    double weightedEntropy = 0;
    for (int valueIndex = 0; valueIndex < numAttributeValues; valueIndex++) {
        double subsetSize = subsetSizes[valueIndex];
        if (subsetSize > 0) {
            double subsetEntropy = 0;
            for (double count : jointCounts[valueIndex]) {
                double probability = count / subsetSize;
                // 0 * log(0) is taken as 0.
                if (probability != 0) {
                    subsetEntropy -= probability * Math.log(probability);
                } // Of if
            } // Of for count
            weightedEntropy += subsetSize / totalInstances * subsetEntropy;
        } // Of if
    } // Of for valueIndex
    return weightedEntropy;
}// Of conditionalEntropy
按属性取值分割数据集
/**
 **********************************
 * Split the available instances according to their values of the given
 * attribute. Block {@code v} holds, in their original order, the indices of
 * the available instances whose value of the attribute is {@code v}.
 *
 * @param paraAttribute
 *            The attribute to split on.
 * @return The blocks, one per attribute value (a block may be empty).
 **********************************
 */
public int[][] splitData(int paraAttribute) {
    int numValues = dataset.attribute(paraAttribute).numValues();

    // Pass 1: how many available instances take each value?
    int[] blockSizes = new int[numValues];
    for (int instanceIndex : availableInstances) {
        blockSizes[(int) dataset.instance(instanceIndex).value(paraAttribute)]++;
    } // Of for instanceIndex

    int[][] resultBlocks = new int[numValues][];
    for (int valueIndex = 0; valueIndex < numValues; valueIndex++) {
        resultBlocks[valueIndex] = new int[blockSizes[valueIndex]];
    } // Of for valueIndex

    // Pass 2: fill each block through its own write cursor.
    int[] cursors = new int[numValues];
    for (int instanceIndex : availableInstances) {
        int valueIndex = (int) dataset.instance(instanceIndex).value(paraAttribute);
        resultBlocks[valueIndex][cursors[valueIndex]++] = instanceIndex;
    } // Of for instanceIndex
    return resultBlocks;
}// Of splitData
树的创建
/**
 **********************************
 * Build the tree recursively. This node stays a leaf (children == null)
 * when any of the following holds:
 * <ul>
 * <li>its block is pure;</li>
 * <li>its block has at most {@code smallBlockThreshold} instances;</li>
 * <li>no attribute is left to split on;</li>
 * <li>no attribute could be selected.</li>
 * </ul>
 * Otherwise it splits on the best attribute and builds one child per
 * non-empty sub-block.
 **********************************
 */
public void buildTree() {
    if (pureJudge(availableInstances)) {
        return;
    } // Of if
    if (availableInstances.length <= smallBlockThreshold) {
        return;
    } // Of if
    // All attributes already used (e.g. identical features, different
    // labels): selecting would yield -1 and splitData(-1) would throw.
    if (availableAttributes.length == 0) {
        return;
    } // Of if

    selectBestAttribute();
    if (splitAttribute < 0) {
        return;
    } // Of if
    int[][] tempSubBlocks = splitData(splitAttribute);

    // Construct the remaining attribute set: every available attribute except
    // the split one. The write cursor does not assume any ordering.
    int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
    int tempCount = 0;
    for (int tempAttribute : availableAttributes) {
        if (tempAttribute != splitAttribute) {
            tempRemainingAttributes[tempCount++] = tempAttribute;
        } // Of if
    } // Of for tempAttribute

    // Construct children; an empty sub-block leaves its child null, and
    // classify() falls back to this node's label for it.
    children = new ID3[tempSubBlocks.length];
    for (int i = 0; i < children.length; i++) {
        if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
            children[i] = null;
            continue;
        } // Of if
        children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);
        // Important code: do this recursively
        children[i].buildTree();
    } // Of for i
}// Of buildTree
/**
 ********************
 * The constructor for an inner node. Shares the given arrays and dataset by
 * reference (no copies are made).
 *
 * @param paraDataset
 *            The given dataset.
 * @param paraAvailableInstances
 *            Indices of the instances covered by this node.
 * @param paraAvailableAttributes
 *            Indices of the attributes this node may still split on.
 ********************
 */
public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
    // Copy its reference instead of clone the availableInstances.
    dataset = paraDataset;
    availableInstances = paraAvailableInstances;
    availableAttributes = paraAvailableAttributes;
    // Keep numClasses consistent with the first constructor; it would
    // otherwise stay 0 for every child node.
    numClasses = dataset.classAttribute().numValues();

    // Initialize.
    children = null;
    // Determine the label by simple voting.
    label = getMajorityClass(availableInstances);
    // Determine whether or not it is pure.
    pure = pureJudge(availableInstances);
}// Of the second constructor
测试
/**
 **********************************
 * Classify an instance by walking down the tree along the instance's values
 * of the split attributes. This node's majority label is returned in any of
 * these cases:
 * <ul>
 * <li>this node is a leaf;</li>
 * <li>the instance's split value is missing;</li>
 * <li>the split value has no matching branch;</li>
 * <li>the matching branch had no training instances.</li>
 * </ul>
 *
 * @param paraInstance
 *            The given instance.
 * @return The prediction.
 **********************************
 */
public int classify(Instance paraInstance) {
    if (children == null) {
        return label;
    } // Of if
    // A missing value is NaN, which (int) would silently turn into 0 and send
    // the instance down the first branch.
    if (paraInstance.isMissing(splitAttribute)) {
        return label;
    } // Of if
    int tempValue = (int) paraInstance.value(splitAttribute);
    if ((tempValue < 0) || (tempValue >= children.length)) {
        return label;
    } // Of if
    ID3 tempChild = children[tempValue];
    if (tempChild == null) {
        return label;
    } // Of if
    return tempChild.classify(paraInstance);
}// Of classify
/**
 **********************************
 * Test on a testing set: the fraction of its instances whose predicted class
 * equals their actual class value.
 *
 * @param paraDataset
 *            The given testing data.
 * @return The accuracy (NaN for an empty testing set).
 **********************************
 */
public double test(Instances paraDataset) {
    int tempTotal = paraDataset.numInstances();
    int tempHits = 0;
    for (int i = 0; i < tempTotal; i++) {
        int tempActual = (int) paraDataset.instance(i).classValue();
        if (classify(paraDataset.instance(i)) == tempActual) {
            tempHits++;
        } // Of if
    } // Of for i
    return (double) tempHits / tempTotal;
}// Of test