决策树是最经典的机器学习算法,首先复习一些学习决策树过程中会用到的基本理论。
基本理论
1.信息熵
首先了解什么是信息,什么是熵。
熵:事件的不确定性称为熵。不确定性越大,熵越大。确定事件的熵为 0。
事件各种可能性的分布越均匀,熵越大,因为结果越不确定。事件各种可能性的分布越不均匀,熵越小。因为结果有偏向性,不确定性在减少。
信息:通过获取额外的信息可以增加事件的确定性,即消除一部分熵。消除了多少熵就是获得了多少信息。
信息熵:消除不确定性所需信息量的度量。
现在继续学习常用的信息熵公式。
抛一次硬币有正反两种情况,正反各占 1/2,这个不确定事件的信息熵我们定义为 1 bit。
抛三次硬币有种情况。
对于有 8 种可能性的等概率事件, 熵为 bit。
假设等概率事件的可能性个数为n个,则熵为。
对于非等概率事件来说,事件的熵应该是各个选项的概率乘以各自的熵,然后加和。
在等概率事件中,如果有 10种可能,则每种可能的概率为 1/10。也就是说,等可能事件的概率 p 为可能性个数 n 的倒数。
那么在非等概率事件中,我们可以把概率 p 的倒数 1/p 看做等概率事件的可能性个数 n。我们就得到一般事件的熵:
其中,n为事件可能性的个数。
化简后可得:
这就是常用的信息熵公式。
2.信息增益
决策树中属性划分的标准是让每个分支纯度更高,其实就是尽可能增加分类的确定性。而熵表示了事件的不确定性,消除熵可以增加事件的确定性,所以只需计算划分前后熵的变化。
划分前事件X的熵:。
按照属性 A 划分后事件X的熵:。
则 就是划分之后熵的变化。
信息增益的由来是消除了多少熵就相当于增加了多少信息,
所以信息增益
依次计算各个属性的信息增益,选择信息增益最大的那个属性来划分分支即可。当然,根据属性的信息熵大小划分也可以。
决策树
先简单学习一下决策树的概念:决策树是一棵类似 if-else 的判断树,为决策而构建的树。
我们举例所用的的数据集如下:
@relation weather
@attribute Outlook {Sunny, Overcast, Rain}
@attribute Temperature {Hot, Mild, Cool}
@attribute Humidity {High, Normal, Low}
@attribute Windy {FALSE, TRUE}
@attribute Play {N, P}
@data
Sunny,Hot,High,FALSE,N
Sunny,Hot,High,TRUE,N
Overcast,Hot,High,FALSE,P
Rain,Mild,High,FALSE,P
Rain,Cool,Normal,FALSE,P
Rain,Cool,Normal,TRUE,N
Overcast,Cool,Normal,TRUE,P
Sunny,Mild,High,FALSE,N
Sunny,Cool,Normal,FALSE,P
Rain,Mild,Normal,FALSE,P
Sunny,Mild,Normal,TRUE,P
Overcast,Mild,High,TRUE,P
Overcast,Hot,Normal,FALSE,P
Rain,Mild,High,TRUE,N
从数据集中可以得出如下图所示的决策树。
那么如何使用决策树呢?
举例说明:以数据集中的第一条为例,首先可知Outlook = Sunny。由决策树中所示,继续判断Humidity,再由Humidity = High可知Play =Yes。
接下来举例讲解按照属性A划分后事件X的信息熵如何计算:
设A为Outlook。
设有5天Outlook = Sunny,其中3天Play= No, 2天Play= Yes;
4天Outlook = Overcast,其中0天Play= No ,4天Play= Yes;
5天Outlook = Rain,其中2天Play= No, 3天Play= Yes;
我们可以计算:
算法步骤:1.读取数据
2.根据属性的信息熵大小选取属性作为父节点
3.从可选择属性中删除该属性,避免重复选择
4.根据信息熵选取属性递归建立决策树
输入:weather.arff数据集
输出:算法准确度。
优化目标:可能没有优化目标。
代码如下:
package knn5;
import java.io.FileReader;
import java.util.Arrays;
import weka.core.*;
public class ID3 {
Instances dataset;
boolean pure;
int numClasses;
int[] availableInstances;
int[] availableAttributes;
int splitAttribute;
ID3[] children;
int label;
int[] predicts;
static int smallBlockThreshold = 3;
public ID3(String paraFilename) {
dataset = null;
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception ee) {
System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
System.exit(0);
} // Of try
dataset.setClassIndex(dataset.numAttributes() - 1);
numClasses = dataset.classAttribute().numValues();
availableInstances = new int[dataset.numInstances()];
for (int i = 0; i < availableInstances.length; i++) {
availableInstances[i] = i;
} // Of for i
availableAttributes = new int[dataset.numAttributes() - 1];
for (int i = 0; i < availableAttributes.length; i++) {
availableAttributes[i] = i;
} // Of for i
children = null;
label = getMajorityClass(availableInstances);
pure = pureJudge(availableInstances);
}// Of the first constructor
public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
dataset = paraDataset;
availableInstances = paraAvailableInstances;
availableAttributes = paraAvailableAttributes;
children = null;
label = getMajorityClass(availableInstances);
pure = pureJudge(availableInstances);
}// Of the second constructor
public boolean pureJudge(int[] paraBlock) {
pure = true;
for (int i = 1; i < paraBlock.length; i++) {
if (dataset.instance(paraBlock[i]).classValue() != dataset.instance(paraBlock[0]).classValue()) {
pure = false;
break;
} // Of if
} // Of for i
return pure;
}// Of pureJudge
public int getMajorityClass(int[] paraBlock) {
int[] tempClassCounts = new int[dataset.numClasses()];
for (int i = 0; i < paraBlock.length; i++) {
tempClassCounts[(int) dataset.instance(paraBlock[i]).classValue()]++;
} // Of for i
int resultMajorityClass = -1;
int tempMaxCount = -1;
for(int i = 0; i < tempClassCounts.length; i++) {
if (tempMaxCount < tempClassCounts[i]) {
resultMajorityClass = i;
tempMaxCount = tempClassCounts[i];
} // Of if
} // Of for i
return resultMajorityClass;
}// Of getMajorityClass
public int selectBestAttribute() {
splitAttribute = -1;
double tempMinimalEntropy = 10000;
double tempEntropy;
for (int i = 0; i < availableAttributes.length; i++) {
tempEntropy = conditionalEntropy(availableAttributes[i]);
if (tempMinimalEntropy > tempEntropy) {
tempMinimalEntropy = tempEntropy;
splitAttribute = availableAttributes[i];
} // Of if
} // Of for i
return splitAttribute;
}// Of selectBestAttribute
public double conditionalEntropy(int paraAttribute) {
int tempNumClasses = dataset.numClasses();
int tempNumValues = dataset.attribute(paraAttribute).numValues();
int tempNumInstances = availableInstances.length;
double[] tempValueCounts = new double[tempNumValues];
double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];
int tempClass, tempValue;
for (int i = 0; i < tempNumInstances; i++) {
tempClass = (int) dataset.instance(availableInstances[i]).classValue();
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
tempValueCounts[tempValue]++;
tempCountMatrix[tempValue][tempClass]++;
} // Of for i
double resultEntropy = 0;
double tempEntropy, tempFraction;
for (int i = 0; i < tempNumValues; i++) {
if (tempValueCounts[i] == 0) {
continue;
} // Of if
tempEntropy = 0;
for (int j = 0; j < tempNumClasses; j++) {
tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
if (tempFraction == 0) {
continue;
} // Of if
tempEntropy += -tempFraction * Math.log(tempFraction);
} // Of for j
resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
} // Of for i
return resultEntropy;
}// Of conditionalEntropy
public int[][] splitData(int paraAttribute) {
int tempNumValues = dataset.attribute(paraAttribute).numValues();
int[][] resultBlocks = new int[tempNumValues][];
int[] tempSizes = new int[tempNumValues];
int tempValue;
for (int i = 0; i < availableInstances.length; i++) {
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
tempSizes[tempValue]++;
} // Of for i
for (int i = 0; i < tempNumValues; i++) {
resultBlocks[i] = new int[tempSizes[i]];
} // Of for i
Arrays.fill(tempSizes, 0);
for (int i = 0; i < availableInstances.length; i++) {
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
tempSizes[tempValue]++;
} // Of for i
return resultBlocks;
}// Of splitData
public void buildTree() {
if (pureJudge(availableInstances)) {
return;
} // Of if
if (availableInstances.length <= smallBlockThreshold) {
return;
} // Of if
selectBestAttribute();
int[][] tempSubBlocks = splitData(splitAttribute);
children = new ID3[tempSubBlocks.length];
int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
for (int i = 0; i < availableAttributes.length; i++) {
if (availableAttributes[i] < splitAttribute) {
tempRemainingAttributes[i] = availableAttributes[i];
} else if (availableAttributes[i] > splitAttribute) {
tempRemainingAttributes[i - 1] = availableAttributes[i];
} // Of if
} // Of for i
for (int i = 0; i < children.length; i++) {
if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
children[i] = null;
continue;
} else {
// System.out.println("Building children #" + i + " with
// instances " + Arrays.toString(tempSubBlocks[i]));
children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);
// Important code: do this recursively
children[i].buildTree();
} // Of if
} // Of for i
}// Of buildTree
public int classify(Instance paraInstance) {
if (children == null) {
return label;
} // Of if
ID3 tempChild = children[(int) paraInstance.value(splitAttribute)];
if (tempChild == null) {
return label;
} // Of if
return tempChild.classify(paraInstance);
}// Of classify
public double test(Instances paraDataset) {
double tempCorrect = 0;
for (int i = 0; i < paraDataset.numInstances(); i++) {
if (classify(paraDataset.instance(i)) == (int) paraDataset.instance(i).classValue()) {
tempCorrect++;
} // Of i
} // Of for i
return tempCorrect / paraDataset.numInstances();
}// Of test
public double selfTest() {
return test(dataset);
}// Of selfTest
public static void id3Test() {
ID3 tempID3 = new ID3("C:\\\\Users\\\\ASUS\\\\Desktop\\\\文件\\\\weather.arff");
ID3.smallBlockThreshold = 3;
tempID3.buildTree();
double tempAccuracy = tempID3.selfTest();
System.out.println("The accuracy is: " + tempAccuracy);
}
public static void main(String[] args) {
id3Test();
}// Of main
}// Of class ID3
运行截图: