我们知道ID3是一个最基本的决策树算法。他主要是每次根据InfoGain来选取特征进行分裂,并且没有进行剪枝。
buildClassifier:
public void buildClassifier(Instances data) throws Exception {
// can classifier handle the data?
getCapabilities().testWithFail(data);
// remove instances with missing class
data = new Instances(data);
data.deleteWithMissingClass();
//递归构造决策树
makeTree(data);
}
这里没有什么好写的,只需看最后一行,makeTree这个函数。
makeTree:
// Check if no instances have reached this node.
// 如果该树为空,则返回
if (data.numInstances() == 0) {
m_Attribute = null;
m_ClassValue = Instance.missingValue();
m_Distribution = new double[data.numClasses()];
return;
}
检验此时的树是否已经为空,如果为空,说明递归完成。上一个分裂节点,就是叶子。
// Compute attribute with maximum information gain.
double[] infoGains = new double[data.numAttributes()];
Enumeration attEnum = data.enumerateAttributes();
while (attEnum.hasMoreElements()) {
Attribute att = (Attribute) attEnum.nextElement();
infoGains[att.index()] = computeInfoGain(data, att);
}
m_Attribute = data.attribute(Utils.maxIndex(infoGains));
这块就是计算每个属性的InfoGain,选出对应最大那个作为分裂属性。简单易懂的代码!(待会看看computeInfoGain这个函数)
// Make leaf if information gain is zero.
// Otherwise create successors.
if (Utils.eq(infoGains[m_Attribute.index()], 0)) {
m_Attribute = null;
m_Distribution = new double[data.numClasses()];
Enumeration instEnum = data.enumerateInstances();
while (instEnum.hasMoreElements()) {
Instance inst = (Instance) instEnum.nextElement();
m_Distribution[(int) inst.classValue()]++;
}
Utils.normalize(m_Distribution);
m_ClassValue = Utils.maxIndex(m_Distribution);
m_ClassAttribute = data.classAttribute();
} else {
Instances[] splitData = splitData(data, m_Attribute);
m_Successors = new Id3[m_Attribute.numValues()];
for (int j = 0; j < m_Attribute.numValues(); j++) {
m_Successors[j] = new Id3();
m_Successors[j].makeTree(splitData[j]);
}
}
第一个判断就是问此时InfoGain是否为0,如果InfoGain=0,那么意味着这个时候,该节点已经是叶子(因为全部样本属于同一个class了!)。
那么,开始计算m_Distribution,其实这个m_Distribution没啥太大用处,因为这个子树的样本肯定属于同一类,其他类全是0.
如果InfoGain!=0,意味着还需要继续分类。那么,我们已经知道要分类的属性了,接下来只要根据该属性,将原来的数据分成几个部分(该属性有几种取值,就分成几个),然后再递归地调用makeTree即可。用m_Successors存储所有子树。
computeInfoGain:
private double computeInfoGain(Instances data, Attribute att)
throws Exception {
double infoGain = computeEntropy(data);
//若att有k种取值,则分成k个部分
Instances[] splitData = splitData(data, att);
for (int j = 0; j < att.numValues(); j++) {
if (splitData[j].numInstances() > 0) {
infoGain -= ((double) splitData[j].numInstances() /
(double) data.numInstances()) *
computeEntropy(splitData[j]);
}
}
return infoGain;
}
这个也很容易看懂,只要知道infoGain的计算公式:
H(D)就是该属性的熵(这个就不说了)
computeEntropy:
<span style="font-size:14px;">private double computeEntropy(Instances data) throws Exception {
//统计每个类各有多少样本
double [] classCounts = new double[data.numClasses()];
Enumeration instEnum = data.enumerateInstances();
while (instEnum.hasMoreElements()) {
Instance inst = (Instance) instEnum.nextElement();
classCounts[(int) inst.classValue()]++;
}
double entropy = 0;
for (int j = 0; j < data.numClasses(); j++) {
//classCounts等于0,那么这部分pi*log(pi)=0
if (classCounts[j] > 0) {
entropy -= classCounts[j] * Utils.log2(classCounts[j]);
}
}
//之前if里没有包含分母,这里除以原来公式中的分母
entropy /= (double) data.numInstances();
return entropy + Utils.log2(data.numInstances());
}</span>
都在注释里了。
splitData:
private Instances[] splitData(Instances data, Attribute att) {
Instances[] splitData = new Instances[att.numValues()];
for (int j = 0; j < att.numValues(); j++) {
//初始化,把数据信息给子树,这里不是复制data给splitData!
splitData[j] = new Instances(data, data.numInstances());
}
Enumeration instEnum = data.enumerateInstances();
while (instEnum.hasMoreElements()) {
Instance inst = (Instance) instEnum.nextElement();
//inst.value(att)返回的是inst对应该属性的值
splitData[(int) inst.value(att)].add(inst);
}
for (int i = 0; i < splitData.length; i++) {
splitData[i].compactify();
}
return splitData;
}
这里基本也是挺直观的,建立一个Instances数组,然后每个坑存放一个子集。这里这个inst.value(att)有点不理解的地方,也就是说,他已经把每个属性的值转换到0-k了。
那个compactify()就是把信息改下,使得数据和info对应起来。
weka的ID3基本就是这些函数了,然后我有个最大的感觉就是他处理的数据形式有限。目前还没找到,如何处理numeric的代码~~~奇怪奇怪!!!