ID3算法思想以及实现

1. 决策树原理

数据挖掘中的分类主要包括基于决策树的分类、基于规则的分类、基于神经网络的分类、基于支持向量机的分类、基于朴素贝叶斯的分类等。机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。数据挖掘中决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测。

[1]Hunt算法

Hunt算法是许多决策树算法的基础,包括ID3、C4.5和CART。

[2]ID3算法

ID3算法的核心是在决策树各级结点上选择属性时,用信息增益[information gain]作为属性的选择标准,以使得在每一个非叶结点进行测试时,能获得关于被测试记录最大的类别信息。

[3]C4.5算法

C4.5算法继承了ID3[Iterative Dichotomiser 3]算法的优点,并在以下几个方面对ID3算法进行了改进:

  • 用信息增益率来选择属性,克服了用信息增益选择属性时偏向选择取值多的属性的不足。
  • 在树构造过程中进行剪枝。
  • 能够完成对连续属性的离散化处理。
  • 能够对不完整数据进行处理。

[4]CART算法

CART[Classification And Regression Tree]算法采用一种二分递归分割的技术,将当前的样本集分为两个子样本集,使得生成的的每个非叶子节点都有两个分支。因此,CART算法生成的决策树是结构简洁的二叉树。采用一种二分递归分割的技术,将当前的样本集分为两个子样本集,使得生成的的每个非叶子节点都有两个分支。因此,CART算法生成的决策树是结构简洁的二叉树。

[5]SLIQ算法

SLIQ[super-vised learning in quest]算法对C4.5决策树分类算法的实现方法进行了改进,在决策树的构造过程中采用了“预排序”和“广度优先策略”两种技术。

[6]SPRINT算法

为了减少驻留于内存的数据量,SPRINT算法[scalable parallelizable induction of decision trees]进一步改进了决策树算法的数据结构,去掉了在SLIQ中需要驻留于内存的类别列表,将它的类别列合并到每个属性列表中。这样,在遍历每个属性列表寻找当前结点的最优分裂标准时,不必参照其他信息,将对结点的分裂表现在对属性列表的分裂,即将每个属性列表分成两个,分别存放属于各个结点的记录。

总结:除此之外,常见的决策树算法还有CHAID、Quest和C5.0等。ID3、C4.5和CART都采用贪心[即非回溯的]方法,其中决策树以自顶向下递归的分治方式构造。大多数决策树归纳算法都沿用这种自顶向下方法,从训练元组集和它们相关联的类标号开始构造决策树。随着树的构建,训练集递归地划分成较小的子集。

2. 实验数据weather.nominal.arff

@relation weather.symbolic

@attribute outlook {sunny, overcast, rainy}
@attribute temperature {hot, mild, cool}
@attribute humidity {high, normal}
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}

@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no

3. Weka实现

[1]Preprocess选项

[2]Classify选项

[3]Classifier output选项

=== Run information ===

Scheme:weka.classifiers.trees.Id3
Relation:     weather.symbolic
Instances:    14
Attributes:   5
              outlook
              temperature
              humidity
              windy
              play
Test mode:10-fold cross-validation

=== Classifier model (full training set) ===

Id3


outlook = sunny
|  humidity = high: no
|  humidity = normal: yes
outlook = overcast: yes
outlook = rainy
|  windy = TRUE: no
|  windy = FALSE: yes

Time taken to build model: 0.01 seconds

=== Stratified cross-validation ===
=== Summary ===

Correctly Classified Instances          12               85.7143 %
Incorrectly Classified Instances         2               14.2857 %
Kappa statistic                          0.6889
Mean absolute error                      0.1429
Root mean squared error                  0.378
Relative absolute error                 30      %
Root relative squared error             76.6097 %
Total Number of Instances               14     

=== Detailed Accuracy By Class ===

               TP Rate   FP Rate   Precision   Recall  F-Measure   ROC Area  Class
                 0.889     0.2        0.889     0.889     0.889      0.844    yes
                 0.8       0.111      0.8       0.8       0.8        0.844    no
Weighted Avg.    0.857     0.168      0.857     0.857     0.857      0.844

=== Confusion Matrix ===

 a b   <-- classified as
 8 1 | a = yes
 1 4 | b = no

解析:

[1]统计量

  • Kappa statistic:Kappa统计
  • Mean absolute error:平均绝对误差                    
  • Root mean squared error:均方根误差               
  • Relative absolute error:相对绝对误差               
  • Root relative squared error:相对平方根误差

[2]相关术语

  • TP(true positive):正确的肯定
  • TN(true negative):正确的否定
  • FP(false positive):错误的肯定
  • FN(false negative):错误的否定
  • Precision:精确率  
  • Recall:反馈率
  • ROC(receiver operating characteristic):接受者操作特性
  • F-Measure(F-Score):F值
  • Confusion Matrix:混淆矩阵

4. 信息熵概念

解析:信息熵方程,如下所示:

Entropy = 系统的凌乱程度,使用算法ID3,C4.5和C5.0生成树算法使用熵,这一度量是基于信息学理论中熵的概念。

5. ID3决策树算法伪代码

算法:Generate_decision_tree(samples, attribute)。由给定的训练数据产生一棵判定树。

输入:训练样本samples,由离散值属性表示;候选属性的集合attribute_list。

输出:一棵判定树。

方法:

Generate_decision_tree(samples, attribute_list)

(1)创建结点N;

(2)if samples都在同一个类C then                             // 类标号属性的值均为C,其候选属性值不考虑

(3)return N作为叶结点,以类C标记;

(4)if attribut_list为空 then      

(5)return N作为叶结点,标记为samples中最普通的类;            // 类标号属性值数量最大的那个

(6)选择attribute_list中具有最高信息增益的属性best_attribute; // 找出最好的划分属性

(7)标记结点N为best_attribute;

(8)for each best_attribute中的未知值ai                     // 将样本samples按照best_attribute进行划分

(9)由结点N长出一个条件为best_attribute = ai的分枝;

(10)设si是samples中best_attribute = ai的样本的集合;        // 一个分区

(11)if si为空 then

(12)加上一个树叶,标记为samples中最普通的类;                 // 从样本中找出类标号数量最多的,作为此节点的标记  

(13)else加上一个由Generate_decision_tree(si, attribute_list–best_attribute)返回的结点;  
                                                         // 对数据子集si递归调用,此时候选属性已删除best_attribute

6. 计算过程

[1]在没有给定任何天气信息时,根据历史数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为:

[2]下面我们计算当已知变量outlook的值时,信息熵为多少。

  • outlook = sunny时,2/5的概率打球,3/5的概率不打球,则entropy = 0.971

  • outlook = overcast时,4/4的概率打球,0/4的概率不打球,则entropy = 0

  • outlook = rainy时,3/5的概率打球,2/5的概率不打球,则entropy = 0.971

根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693。信息增溢gain(outlook)为0.940 - 0.693 = 0.247,故系统熵就从0.940下降到了0.693。同理,gain(temperature) = 0.029,gain(humidity) = 0.152,gain(windy) = 0.048。gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。

[3]确定N1取temperature、humidity还是windy?

在已知outlook = sunny的情况,根据历史数据,我们作出一张类似上表的表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。其余计算,依次类推。

7. 代码实现

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
 
import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;
 
public class ID3 {
    private ArrayList<String> attribute = new ArrayList<String>();                            // 存储属性的名称
    private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // 存储每个属性的取值
    private ArrayList<String[]> data = new ArrayList<String[]>();;                            // 原始数据
    int decatt;                                                                               // 决策变量在属性集中的索引
    public static final String patternString = "@attribute(.*)[{](.*?)[}]";
 
    Document xmldoc;
    Element root;
 
    public ID3() {
        xmldoc = DocumentHelper.createDocument();
        root = xmldoc.addElement("root");
        root.addElement("DecisionTree").addAttribute("value", "null");
    }
 
    public static void main(String[] args) {
        ID3 inst = new ID3();
        inst.readARFF(new File("/home/wss/weather.nominal.arff"));
        inst.setDec("play");
        LinkedList<Integer> ll=new LinkedList<Integer>();
        for(int i=0;i<inst.attribute.size();i++){
            if(i!=inst.decatt)
                ll.add(i);
        }
        ArrayList<Integer> al=new ArrayList<Integer>();
        for(int i=0;i<inst.data.size();i++){
            al.add(i);
        }
        inst.buildDT("DecisionTree", "null", al, ll);
        inst.writeXML("/home/wss/result.xml");
        return;
    }
 
    // 读取arff文件,给attribute、attributevalue、data赋值
    public void readARFF(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
                Matcher matcher = pattern.matcher(line);
                if (matcher.find()) {
                    attribute.add(matcher.group(1).trim());
                    String[] values = matcher.group(2).split(",");
                    ArrayList<String> al = new ArrayList<String>(values.length);
                    for (String value : values) {
                        al.add(value.trim());
                    }
                    attributevalue.add(al);
                } else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        if(line=="")
                            continue;
                        String[] row = line.split(",");
                        data.add(row);
                    }
                } else {
                    continue;
                }
            }
            br.close();
        } catch (IOException e1) {
            e1.printStackTrace();
        }
    }
 
    // 设置决策变量
    public void setDec(int n) {
        if (n < 0 || n >= attribute.size()) {
            System.err.println("决策变量指定错误。");
            System.exit(2);
        }
        decatt = n;
    }
    public void setDec(String name) {
        int n = attribute.indexOf(name);
        setDec(n);
    }
 
    // 给一个样本(数组中是各种情况的计数),计算它的熵
    public double getEntropy(int[] arr) {
        double entropy = 0.0;
        int sum = 0;
        for (int i = 0; i < arr.length; i++) {
            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);
            sum += arr[i];
        }
        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);
        entropy /= sum;
        return entropy;
    }
 
    // 给一个样本数组及样本的算术和,计算它的熵
    public double getEntropy(int[] arr, int sum) {
        double entropy = 0.0;
        for (int i = 0; i < arr.length; i++) {
            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);
        }
        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);
        entropy /= sum;
        return entropy;
    }
 
    public boolean infoPure(ArrayList<Integer> subset) {
        String value = data.get(subset.get(0))[decatt];
        for (int i = 1; i < subset.size(); i++) {
            String next=data.get(subset.get(i))[decatt];
            // equals表示对象内容相同,==表示两个对象指向的是同一片内存
            if (!value.equals(next))
                return false;
        }
        return true;
    }
 
    // 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵
    public double calNodeEntropy(ArrayList<Integer> subset, int index) {
        int sum = subset.size();
        double entropy = 0.0;
        int[][] info = new int[attributevalue.get(index).size()][];
        for (int i = 0; i < info.length; i++)
            info[i] = new int[attributevalue.get(decatt).size()];
        int[] count = new int[attributevalue.get(index).size()];
        for (int i = 0; i < sum; i++) {
            int n = subset.get(i);
            String nodevalue = data.get(n)[index];
            int nodeind = attributevalue.get(index).indexOf(nodevalue);
            count[nodeind]++;
            String decvalue = data.get(n)[decatt];
            int decind = attributevalue.get(decatt).indexOf(decvalue);
            info[nodeind][decind]++;
        }
        for (int i = 0; i < info.length; i++) {
            entropy += getEntropy(info[i]) * count[i] / sum;
        }
        return entropy;
    }
 
    // 构建决策树
    public void buildDT(String name, String value, ArrayList<Integer> subset,
            LinkedList<Integer> selatt) {
        Element ele = null;
        @SuppressWarnings("unchecked")
        List<Element> list = root.selectNodes("//"+name);
        Iterator<Element> iter=list.iterator();
        while(iter.hasNext()){
            ele=iter.next();
            if(ele.attributeValue("value").equals(value))
                break;
        }
        if (infoPure(subset)) {
            ele.setText(data.get(subset.get(0))[decatt]);
            return;
        }
        int minIndex = -1;
        double minEntropy = Double.MAX_VALUE;
        for (int i = 0; i < selatt.size(); i++) {
            if (i == decatt)
                continue;
            double entropy = calNodeEntropy(subset, selatt.get(i));
            if (entropy < minEntropy) {
                minIndex = selatt.get(i);
                minEntropy = entropy;
            }
        }
        String nodeName = attribute.get(minIndex);
        selatt.remove(new Integer(minIndex));
        ArrayList<String> attvalues = attributevalue.get(minIndex);
        for (String val : attvalues) {
            ele.addElement(nodeName).addAttribute("value", val);
            ArrayList<Integer> al = new ArrayList<Integer>();
            for (int i = 0; i < subset.size(); i++) {
                if (data.get(subset.get(i))[minIndex].equals(val)) {
                    al.add(subset.get(i));
                }
            }
            buildDT(nodeName, val, al, selatt);
        }
    }
 
    // 把xml写入文件
    public void writeXML(String filename) {
        try {
            File file = new File(filename);
            if (!file.exists())
                file.createNewFile();
            FileWriter fw = new FileWriter(file);
            OutputFormat format = OutputFormat.createPrettyPrint();
            XMLWriter output = new XMLWriter(fw, format);
            output.write(xmldoc);
            output.close();
        } catch (IOException e) {
            System.out.println(e.getMessage());
        }
    }
}

结果输出,如下所示:

<?xml version="1.0" encoding="UTF-8"?>
 
<root>
	<DecisionTree value="null">
        	<outlook value="sunny">
      			<humidity value="high">no</humidity>
      			<humidity value="normal">yes</humidity>
    		</outlook>
    		<outlook value="overcast">yes</outlook>
    		<outlook value="rainy">
      			<windy value="TRUE">no</windy>
      			<windy value="FALSE">yes</windy>
    		</outlook>
  	</DecisionTree>
</root>

结果输出,如下所示:

参考文献:

[1] 决策树:http://zh.wikipedia.org/zh-cn/%E5%86%B3%E7%AD%96%E6%A0%91

[2] 《数据挖掘:概念与技术》

[3] 决策树分类算法:如流,新一代智能工作平台

[4] 归纳决策树ID3:归纳决策树ID3(Java实现) - 张朝阳 - 博客园

[5] ID3决策树算法伪代码及注解:ID3 决策树算法伪代码及注解_数据挖掘者DATAMiner-CSDN博客_id3算法伪代码

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

NLP工程化

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值