最近由于实际需要,要把以前看过的算法复习下,我基本每行代码都按自己的意思理解了下,不知道对不对,不过贴出重点部分,供大家围观。第一篇先找简单的下手,嘿嘿,ID3,来吧。 /** 变量定义: */ /** 保存决策树节点的数组 */ private Id3[] m_Successors; /** 分裂属性 */ private Attribute m_Attribute; /** 叶子节点的分类 */ private double m_ClassValue; /** 叶子节点每种分类占的比例 */ private double[] m_Distribution; /** 类属性 */ private Attribute m_ClassAttribute; /** 主要函数: */ public void buildClassifier(Instances data) throws Exception { // 验证数据类型是否为算法所支持的类型 getCapabilities().testWithFail(data); // 读入数据集中的数据 data = new Instances(data); // 删除不完整的数据 data.deleteWithMissingClass(); // 开始建立ID3决策树 makeTree(data); } private void makeTree(Instances data) throws Exception { // 如果当前节点不包含任何实例 if (data.numInstances() == 0) { // 分裂属性为空,即不再分裂 m_Attribute = null; // 该节点类属性为空 m_ClassValue = Instance.missingValue(); // 按照最终分类的种类数创建空间 m_Distribution = new double[data.numClasses()]; return; } // 创建储存所有分类信息增益的数组 double[] infoGains = new double[data.numAttributes()]; // 创建用于遍历属性的迭代器 Enumeration attEnum = data.enumerateAttributes(); // 遍历每个属性,计算信息增益 while (attEnum.hasMoreElements()) { Attribute att = (Attribute) attEnum.nextElement(); infoGains[att.index()] = computeInfoGain(data, att); } // 分裂属性为信息增益最大的属性 m_Attribute = data.attribute(Utils.maxIndex(infoGains)); // 如果最大信息增益为0 if (Utils.eq(infoGains[m_Attribute.index()], 0)) { // 分裂属性为空 m_Attribute = null; // 按照类的种类数,创建空间 m_Distribution = new double[data.numClasses()]; // 创建用于遍历实例的迭代器 Enumeration instEnum = data.enumerateInstances(); // 遍历每个实例,统计类属性的分类情况 while (instEnum.hasMoreElements()) { Instance inst = (Instance) instEnum.nextElement(); m_Distribution[(int) inst.classValue()]++; } // 将数据归一化到0,1区间内,实际上这步个人认为没必要 Utils.normalize(m_Distribution); // 当前节点的类属性设置为覆盖实例最多的属性 m_ClassValue = Utils.maxIndex(m_Distribution); // 当前节点的类属性为数据集的类属性 m_ClassAttribute = data.classAttribute(); // 如果最大信息增益不为0,则继续分裂 } else { // 按分裂属性进行分裂 Instances[] splitData = splitData(data, m_Attribute); // 按照节点的分裂个数创建空间 m_Successors = new Id3[m_Attribute.numValues()]; // 对于每个分裂出的节点 for (int j = 0; j < m_Attribute.numValues(); j++) { // 构造ID3分类器 m_Successors[j] = new Id3(); // 对于分裂出的实例继续分类 m_Successors[j].makeTree(splitData[j]); } } } // 计算信息增益,算法化简后和公式一样的效果 private double computeInfoGain(Instances data, Attribute att) throws Exception { double infoGain = computeEntropy(data); Instances[] splitData = splitData(data, att); for (int j = 0; j < att.numValues(); j++) { if (splitData[j].numInstances() > 0) { infoGain -= ((double) splitData[j].numInstances() / (double) data.numInstances()) * computeEntropy(splitData[j]); } } return infoGain; }