ID3算法相对简单,weka的实现也容易理解。首先介绍一下大致算法。算法概述如下。
1.选择一种度量(ID3选择的是信息增益),计算每个属性对于该度量的值。
2.根据结果选择一个属性进行分支。
3.如果每个分支全部属于一个类或者已经没有候选属性。则停止,否则对每个分支进行1,2操作。
下面对weka的ID3 class 作介绍,主要涉及到makeTree(Instances data),computeInfoGain(data, att),splitData(Instances data, Attribute att)三个函数。其中makeTree是入口函数,computeInfoGain的作用是计算信息增益,splitData的作用是分支。首先看makeTree函数。
private void makeTree(Instances data) throws Exception {
// Check if no instances have reached this node.
if (data.numInstances() == 0) {
m_Attribute = null;
m_ClassValue = Instance.missingValue();
m_Distribution = new double[data.numClasses()];
return;
}
// Compute attribute with maximum information gain.
double[] infoGains = new double[data.numAttributes()];
Enumeration attEnum = data.enumerateAttributes();
/**
* 对每个属性计算信息增益
*/
while (attEnum.hasMoreElements()) {
Attribute att = (Attribute) attEnum.nextElement();
infoGains[att.index()] = computeInfoGain(data, att);
}
m_Attribute = data.attribute(Utils.maxIndex(infoGains));
// Make leaf if information gain is zero.
// Otherwise create successors.
if (Utils.eq(infoGains[m_Attribute.index()], 0)) {
m_Attribute = null;
m_Distribution = new double[data.numClasses()];
Enumeration instEnum = data.enumerateInstances();
while (instEnum.hasMoreElements()) {
Instance inst = (Instance) instEnum.nextElement();
m_Distribution[(int) inst.classValue()]++;
}
Utils.normalize(m_Distribution);
m_ClassValue = Utils.maxIndex(m_Distribution);
m_ClassAttribute = data.classAttribute();
} else {
Instances[] splitData = splitData(data, m_Attribute);
m_Successors = new Id3[m_Attribute.numValues()];
/**
* 这里对每个分支继续调用id3.makeTree(instatnces)。
*/
for (int j = 0; j < m_Attribute.numValues(); j++) {
m_Successors[j] = new Id3();
m_Successors[j].makeTree(splitData[j]);
}
}
}
通过注释,应该不难理解大致过程。这里需要注意的是 程序里经常会出现Enumeration,这其实就是现在的Ieratorer,当时jdk版本较低,所以用的Enumeration,忽视掉就好了。
下面是splitData。只是按照类的值进行分支,也很容易理解。
private double computeInfoGain(Instances data, Attribute att)
throws Exception {
double infoGain = computeEntropy(data);
Instances[] splitData = splitData(data, att);
for (int j = 0; j < att.numValues(); j++) {
if (splitData[j].numInstances() > 0) {
infoGain -= ((double) splitData[j].numInstances() /
(double) data.numInstances()) *
computeEntropy(splitData[j]);
}
}
return infoGain;
}
至于 computeInfoGain对照公式就很容易理解了。这里只贴出代码
private double computeInfoGain(Instances data, Attribute att)
throws Exception {
double infoGain = computeEntropy(data);
Instances[] splitData = splitData(data, att);
for (int j = 0; j < att.numValues(); j++) {
if (splitData[j].numInstances() > 0) {
infoGain -= ((double) splitData[j].numInstances() /
(double) data.numInstances()) *
computeEntropy(splitData[j]);
}
}
return infoGain;
}