ID3算法的java实现(转载,修改)

转自 http://blog.csdn.net/luowen3405/article/details/6250731

    决策树是以实例为基础的归纳学习算法。它从一组无次序、无规则的元组中推理出决策树表示形式的分类规则。它采用自顶向下的递归方式,在决策树的内部结点进行属性值的比较,并根据不同的属性值从该结点向下分支,叶结点是要学习划分的类。从根到叶结点的一条路径就对应着一条合取规则,整个决策树就对应着一组析取表达式规则。

       1986Quinlan提出了著名的ID3算法。在ID3算法的基础上,1993Quinlan又提出了C4.5算法。为了适应处理大规模数据集的需要,后来又提出了若干改进的算法,其中SLIQ (super-vised learning in quest)SPRINT (scalable parallelizableinduction of decision trees)是比较有代表性的两个算法。

 (1) ID3算法

ID3算法的核心是:在决策树各级结点上选择属性时,用信息增益(information gain)作为属性的选择标准,以使得在每一个非叶结点进行测试时,能获得关于被测试记录最大的类别信息。其具体方法是:检测所有的属性,选择信息增益最大的属性产生决策树结点,由该属性的不同取值建立分支,再对各分支的子集递归调用该方法建立决策树结点的分支,直到所有子集仅包含同一类别的数据为止。最后得到一棵决策树,它可以用来对新的样本进行分类。

ID3算法的优点是:

算法的理论清晰,方法简单,学习能力较强。其缺点是:只对比较小的数据集有效,且对噪声比较敏感,当训练数据集加大时,决策树可能会随之改变。

(2) C4.5算法

C4.5算法继承了ID3算法的优点,并在以下几方面对ID3算法进行了改进:

1) 用信息增益率来选择属性,克服了用信息增益选择属性时偏向选择取值多的属性的不足;

2) 在树构造过程中进行剪枝;

3) 能够完成对连续属性的离散化处理;

4) 能够对不完整数据进行处理。


C4.5算法与其它分类算法如统计方法、神经网络等比较起来有如下优点:产生的分类规则易于理解,准确率较高。其缺点是:在构造树的过程中,需要对数据集进行多次的顺序扫描和排序,因而导致算法的低效。此外,C4.5只适合于能够驻留于内存的数据集,当训练集大得无法在内存容纳时程序无法运行。


(3) SLIQ算法

SLIQ算法对C4.5决策树分类算法的实现方法进行了改进,在决策树的构造过程中采用了“预排序”和“广度优先策略”两种技术。

1) 预排序。对于连续属性在每个内部结点寻找其最优分裂标准时,都需要对训练集按照该属性的取值进行排序,而排序是很浪费时间的操作。为此,SLIQ算法采用了预排序技术。所谓预排序,就是针对每个属性的取值,把所有的记录按照从小到大的顺序进行排序,以消除在决策树的每个结点对数据集进行的排序。具体实现时,需要为训练数据集的每个属性创建一个属性列表,为类别属性创建一个类别列表。


2) 广度优先策略。在C4.5算法中,树的构造是按照深度优先策略完成的,需要对每个属性列表在每个结点处都进行一遍扫描,费时很多,为此,SLIQ采用广度优先策略构造决策树,即在决策树的每一层只需对每个属性列表扫描一次,就可以为当前决策树中每个叶子结点找到最优分裂标准。

SLIQ算法由于采用了上述两种技术,使得该算法能够处理比C4.5大得多的训练集,在一定范围内具有良好的随记录个数和属性个数增长的可伸缩性。


然而它仍然存在如下缺点:

1)由于需要将类别列表存放于内存,而类别列表的元组数与训练集的元组数是相同的,这就一定程度上限制了可以处理的数据集的大小。

2) 由于采用了预排序技术,而排序算法的复杂度本身并不是与记录个数成线性关系,因此,使得SLIQ算法不可能达到随记录数目增长的线性可伸缩性。


(4) SPRINT算法

为了减少驻留于内存的数据量,SPRINT算法进一步改进了决策树算法的数据结构,去掉了在SLIQ中需要驻留于内存的类别列表,将它的类别列合并到每个属性列表中。这样,在遍历每个属性列表寻找当前结点的最优分裂标准时,不必参照其他信息,将对结点的分裂表现在对属性列表的分裂,即将每个属性列表分成两个,分别存放属于各个结点的记录。


 SPRINT算法的优点是在寻找每个结点的最优分裂标准时变得更简单。其缺点是对非分裂属性的属性列表进行分裂变得很困难。解决的办法是对分裂属性进行分裂时用哈希表记录下每个记录属于哪个孩子结点,若内存能够容纳下整个哈希表,其他属性列表的分裂只需参照该哈希表即可。由于哈希表的大小与训练集的大小成正比,当训练集很大时,哈希表可能无法在内存容纳,此时分裂只能分批执行,这使得SPRINT算法的可伸缩性仍然不是很好。

本人对ID3的算法实现做了如下假设与处理:

 

1. 假设所有的属性值域都是分类型或名词离散型的

2.求信息增益时,log函数本来应以2为底,但是为了方便起见,直接调用了java.util.Math类中的以e为底的log函数,无论以什么为底均不会对影响结果产生影响

3.最后的输出并没有以树结构的形式给出,但是可以根据输出结果分析出决策树的结构

 

java实现代码如下

决策树结点类 class TreeNode

package DecisionTree;  
import java.util.ArrayList;  
/** 
 * 决策树结点类 
 * @author mgq 
 * @data 2012.01.09 
 */  
public class TreeNode {  
    private String name; //节点名(分裂属性的名称)  
    private ArrayList<String> rule; //结点的分裂规则  
    ArrayList<TreeNode> child; //子结点集合  
    private ArrayList<ArrayList<String>> datas; //划分到该结点的训练元组  
    private ArrayList<String> candAttr; //划分到该结点的候选属性  
    
    public TreeNode() {  
        this.name = "";  
        this.rule = new ArrayList<String>();  
        this.child = new ArrayList<TreeNode>();  
        this.datas = null;  
        this.candAttr = null;  
    }  
    public ArrayList<TreeNode> getChild() {  
        return child;  
    }  
    public void setChild(ArrayList<TreeNode> child) {  
        this.child = child;  
    }  
    public ArrayList<String> getRule() {  
        return rule;  
    }  
    public void setRule(ArrayList<String> rule) {  
        this.rule = rule;  
    }  
    public String getName() {  
        return name;  
    }  
    public void setName(String name) {  
        this.name = name;  
    }  
    public ArrayList<ArrayList<String>> getDatas() {  
        return datas;  
    }  
    
    public ArrayList<String> getCandAttr() {  
        return candAttr;  
    }  
    public void setCandAttr(ArrayList<String> candAttr) {  
        this.candAttr = candAttr;  
    }
	public void setDatas(ArrayList<ArrayList<String>> datas2) {
		// TODO Auto-generated method stub
		this.datas = datas2;  
		
	}  
}  

决策树构造类 class DecisionTree
package DecisionTree;  
import java.util.ArrayList;  
import java.util.HashMap;  
import java.util.Iterator;  
import java.util.Map;  

/** 
 * 决策树构造类 
 * @author mgq 
 * @data 2012.01.09 
 */  
public class DecisionTree {   
    private Integer attrSelMode;  //最佳分裂属性选择模式,1表示以信息增益度量,2表示以信息增益率度量。暂未实现2  
    public DecisionTree(){  
        this.attrSelMode = 1;  
    }  
      
    public DecisionTree(int attrSelMode) {  
        this.attrSelMode = attrSelMode;  
    }  
    public void setAttrSelMode(Integer attrSelMode) {  
        this.attrSelMode = attrSelMode;  
    }  
    /** 
     * 获取指定数据集中的类别及其计数 
     * @param datas 指定的数据集 
     * @return 类别及其计数的map 
     */  
    public Map<String, Integer> classOfDatas(ArrayList<ArrayList<String>> datas){  
        Map<String, Integer> classes = new HashMap<String, Integer>();  
        String c = "";  
        ArrayList<String> tuple = null;  
        for (int i = 0; i < datas.size(); i++) {  
            tuple = datas.get(i);  
            c = tuple.get(tuple.size() - 1);  
            if (classes.containsKey(c)) {  
                classes.put(c, classes.get(c) + 1);  
            } else {  
                classes.put(c, 1);  
            }  
        }  
        return classes;  
    }  
      
    /** 
     * 获取具有最大计数的类名,即求多数类 
     * @param classes 类的键值集合 
     * @return 多数类的类名 
     */  
    public String maxClass(Map<String, Integer> classes){  
        String maxC = "";  
        int max = -1;  
        Iterator iter = classes.entrySet().iterator();  
        for(int i = 0; iter.hasNext(); i++)  
        {  
            Map.Entry entry = (Map.Entry) iter.next();   
            String key = (String)entry.getKey();  
            Integer val = (Integer) entry.getValue();   
            if(val > max){  
                max = val;  
                maxC = key;  
            }  
        }  
        return maxC;  
    }  
      
    /** 
     * 构造决策树 
     * @param datas 训练元组集合 
     * @param attrList 候选属性集合 
     * @return 决策树根结点 
     */  
    public TreeNode buildTree(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList){  
      System.out.print("候选属性列表: ");  
      for (int i = 0; i < attrList.size(); i++) {  
          System.out.print(" " + attrList.get(i) + " ");  
      }  
        System.out.println();  
        TreeNode node = new TreeNode();  
        node.setDatas(datas);  
        node.setCandAttr(attrList);  
        Map<String, Integer> classes = classOfDatas(datas);  
        String maxC = maxClass(classes);  
        if (classes.size() == 1 || attrList.size() == 0) {  
            node.setName(maxC);  
            return node;  
        }  
        Gain gain = new Gain(datas, attrList);  
        int bestAttrIndex = gain.bestGainAttrIndex();  
        ArrayList<String> rules = gain.getValues(datas, bestAttrIndex);  
        node.setRule(rules);  
        node.setName(attrList.get(bestAttrIndex));  
        if(rules.size() > 2){ //?此处有待商榷  
            attrList.remove(bestAttrIndex);  
        }  
        for (int i = 0; i < rules.size(); i++) {  
            String rule = rules.get(i);  
            ArrayList<ArrayList<String>> di = gain.datasOfValue(bestAttrIndex, rule);  
            for (int j = 0; j < di.size(); j++) {  
                di.get(j).remove(bestAttrIndex);  
            }  
            if (di.size() == 0) {  
                TreeNode leafNode = new TreeNode();  
             
  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值