ID3分类决策树算法

简述

对已知D中元组分类所需要的平均信息公式:

Info(D)=mi=1Pilog2(Pi) (1)
1. 平均信息的解释:
平均信息的解释

  1. 对信息所使用的进制的解释:
    对信息所使用的进制的解释

  2. 由此对于公式(1),我们以2为底,表示我们以2进制对信息进行编码,Info(D)表示我们对D中元组进行全部分类时,以2进制为编码表示这些信息所需要的位数。

属性的选择度量

  1. 根据不同的度量方式我们可以选择不同的度量方法,如使用信息增益作为属性选择度量的ID3算法,使用增益率作为属性选择度量的C4.5算法,使用基尼指数作为属性选择度量的CART算法。这几个算法都是使用不同的属性度量的决策树算法。

ID3算法

  1. ID3算法使用信息增益作为属性的选择度量
  2. 信息增益
    Info(D)=mi=1Pilog2(Pi) (1)
    按属性A进行划分后的新的信息需求为:
    InfoA(D)=vj=1(|Dj|/|D|)Info(Dj) (2)
    信息增益:
    Gain(A)=Info(D)InfoA(D)
    总结:
    信息增益告诉我们通过A上的划分我们得到了多少信息。

决策树

sexcolorsize
mreds
mbluem
fbluem
fyellowb

这里写图片描述

算法思想

先用GetDecisionTreeDFS函数利用训练数据训练出决策树,在对测试数据代进决策树进行测试,从而对他们进行分类。
DataSet类:


    package ID3;

    import java.util.ArrayList;

    public class DataSet {
        public ArrayList<String> attrSet;
        public ArrayList<ArrayList<String>> dataRows;
        protected String targetAttribute;
        public DataSet()
        {
            this.attrSet = null;
            this.targetAttribute = null;
            this.dataRows = new ArrayList<ArrayList<String>>();
        }
        public DataSet(ArrayList<String> attrSet, String targetAttribute)
        {
            this.attrSet = new ArrayList<String>();
            this.attrSet = attrSet;
            this.targetAttribute = targetAttribute;
            this.dataRows = new ArrayList<ArrayList<String>>();
        }
        public void AddRow(ArrayList<String> row)
        {
            dataRows.add(row);
        }

    }

Node类:

    package ID3;

    import java.util.ArrayList;

    public class Node {
        public String attrName;//属性名
        public ArrayList<String> rules;//属性规则
        public ArrayList<Node> children;//子节点集合
        public String targetValue;//目标属性值,只有叶子结点才有的

        public Node(String attrName, ArrayList<String> rules)//树枝节点
        {
            this.attrName = attrName;
            this.rules = rules;
            this.children = new ArrayList<Node>();
            this.targetValue = targetValue;
        }

        public Node(String attrName, String targetValue)//构建叶子节点
        {
            this.attrName = attrName;
            this.rules = rules;
            this.children = new ArrayList<Node>();
            this.targetValue = targetValue;
        }
        /**
         * 递归遍历打印树结构
         * @param root 根节点
         * @param spaceCount 缩进的空格数
         * @param rules 父节点规则,即树枝
         */
        public void PrintTree(Node root, int spaceCount, String rules)
        {
            if(root == null)
            {
                return;
            }
            for(int i = 0; i < spaceCount; i++)
            {
                System.out.println(" ");
            }
            if(root.targetValue != null)
            {
                System.out.println((rules != null ? rules+":":"") + root.targetValue + "(leaf)");
            }
            else
            {
                System.out.println((rules != null ? rules+":":"") + root.attrName);
            }
            if(root.children != null && root.children.size() > 0)
            {
                for(int i = 0; i< root.children.size(); i++)
                {
                    PrintTree(root.children.get(i), spaceCount+2, root.rules.get(i));
                }
            }
        }

        public String Test(String... datas)
        {
            if(datas.length != ID3.originalDataSet.attrSet.size())
            {
                System.out.println("数据有误,不完整");
                return "";
            }
            Node node = this;
            while(node != null)
            {
                if(node.targetValue != null)
                    return node.targetValue;
                String attrName = this.attrName;
                int columnIndex = ID3.originalDataSet.attrSet.indexOf(attrName);
                boolean testRight = false;
                for(String rule : node.rules)
                {
                    if(rule.equals(datas[columnIndex]))
                    {
                        node = node.children.get(node.rules.indexOf(rule));
                        testRight = true;
                        break;
                    }
                }
                if(!testRight)
                    break;
            }

            return null;
        }


    }

ID3类:

     package ID3;

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.Map.Entry;

    public class ID3 {
        public static DataSet originalDataSet;//最初的数据集
        public ID3(DataSet dataset)
        {
            originalDataSet = dataset;
        }
        /**
         * 利用数据集训练出来一个决策树
         * @param dataSet 数据集
         * @return 决策树根节点
         */
        public Node GetDecisionTreeDFS(DataSet dataSet)
        {
            if (dataSet.dataRows == null)
                return null;
            //剩下的都是同一类的,则返回一个叶子节点
            /*for(ArrayList<String> row : dataSet.dataRows)
            {
                System.out.println(row);
            }*/
            if(TargetAttrIsAllSame(dataSet))
                return new Node(dataSet.targetAttribute, dataSet.dataRows.get(0).get(originalDataSet.attrSet.size()));
            //如果属性集为空,则将其归类为数据集中目标属性值最多的哪一个目标属性值

            if(dataSet.attrSet.size() <= 0)
            {
                return new Node(dataSet.targetAttribute, this.GetMajorTargetValue(dataSet));
            }
            //寻找最大的Gain值属性
            String maxGainAttrName = null;
            double maxGain = -1;
            ArrayList<String> rules = new ArrayList<String>();
            for(String attrName : dataSet.attrSet)
            {
                ArrayList<String> tempRules = this.GetAttrRules(dataSet, attrName);
                double gain = GetGain(dataSet, attrName, tempRules);
                if(maxGain < gain)
                {
                    maxGain = gain;
                    maxGainAttrName = attrName; 
                    rules.clear();
                    rules.addAll(tempRules);
                }
            }
                Node node = new Node(maxGainAttrName, rules);//生成一个新节点
                for(int i = 0; i < node.rules.size(); i++)
                {
                    ArrayList<String> newAttrSet = new ArrayList<String>();
                    for(String attr : dataSet.attrSet)
                    {
                        if(attr != maxGainAttrName)
                        {
                            newAttrSet.add(attr);
                        }
                    }
                        //获取新的数据集
                        DataSet newDataSet = FindSpecificDT(dataSet, maxGainAttrName, node.rules.get(i));
                        newDataSet.attrSet = newAttrSet;

                        //递归再继续分类
                        node.children.add(GetDecisionTreeDFS(newDataSet));

                    }

                return node;

        }
        /**
         * 获取属性为attr,而且属性值为value对应的数据集
         * @param dataSet 新生成的数据集
         * @param attr 特定属性
         * @param value 特定的属性值
         * @return 返回数据集
         */
        public DataSet FindSpecificDT(DataSet dataSet, String attr, String value)
        {
            DataSet resultSet = new DataSet(null, originalDataSet.targetAttribute);
            int columIndex = originalDataSet.attrSet.indexOf(attr);
            for(ArrayList<String> row : dataSet.dataRows)
            {
                if(value.equals(row.get(columIndex)))
                {
                    resultSet.AddRow(row);
                }
            }
            return resultSet;
        }
        /**
         * 找出分裂的相应属性rules
         * @param dataSet 数据集
         * @param attrName 属性名
         * @return
         */
        public ArrayList<String> GetAttrRules(DataSet dataSet, String attrName)
        {
            ArrayList<String> result = new ArrayList<String>();
            int columIndex = dataSet.attrSet.indexOf(attrName);
            for(ArrayList<String> row : dataSet.dataRows)
            {
                String value = row.get(columIndex);
                if(!result.contains(value))
                    result.add(value);
            }
            return result;
        }
        /**
         * 返回信息增益
         * @param dataSet 数据集
         * @param attrName 属性名
         * @param rules
         * @return
         */
        public double GetGain(DataSet dataSet, String attrName, ArrayList<String> rules)
        {
            return getEntropy(dataSet, null, null) - getEntropy(dataSet, attrName, rules);
        }
        /**
         * 计算熵
         * @param dataSet 数据集
         * @param attrName 属性名
         * @param rules 分裂准则
         * @return
         */
        public double getEntropy(DataSet dataSet, String attrName, ArrayList<String> rules)
        {
            if(attrName == null)
            {
                Map<String, Integer> map = GetEachTargetValue(dataSet);
                return  CalculateEntropy(map);
            }
            else{
                double result = 0.0;
                for(int i = 0; i < rules.size(); i++)
                {
                    Map<String, Integer> map = GetEachTargetValue(dataSet, attrName, rules.get(i));
                    double entroy = CalculateEntropy(map);
                    double sum = 0.0;

                    for(Entry<String, Integer> entry : map.entrySet())
                    {
                        sum += entry.getValue();
                    }
                    double dtSize = dataSet.dataRows.size();
                    result += (double)(sum/dtSize)*entroy;
                }
                return result;
            }

        }
        /**
         * 计算制定属性的属性值的数量
         * @param dataSet 数据集
         * @param attrName 属性名
         * @param rules 属性值
         * @return
         */
        public Map<String, Integer> GetEachTargetValue(DataSet dataSet, String attrName, String value)
        {
            Map<String, Integer> map = new HashMap<String, Integer>();
            int columIndex = dataSet.attrSet.indexOf(attrName);
            for(int i = 0; i < dataSet.dataRows.size(); i++)
            {
                String targetValue = dataSet.dataRows.get(i).get(originalDataSet.attrSet.size());
                if(value.equals(dataSet.dataRows.get(i).get(columIndex)))
                {
                    if(map.containsKey(targetValue))
                    {
                        map.put(targetValue, map.get(targetValue)+1);
                    }
                    else
                    {
                        map.put(targetValue, 1);
                    }
                }
            }
            return map;
        }
        /**
         * 计算熵
         * @param map <目标属性值,个数>
         * @return 熵
         */
        public double CalculateEntropy(Map<String, Integer> map)
        {
            double sum = 0.0;
            for(Entry<String, Integer> entry : map.entrySet())
            {
                sum += entry.getValue();
            }
            double result = 0.0;
            for(Entry<String, Integer> entry : map.entrySet())
            {
                int value = entry.getValue();
                if(value == 0)
                    continue;
                result += -((double)value/sum)*(Math.log((double)value/sum)/Math.log(2.0));
            }
            return result;
        }
        /**
         * 判断数据集里的属性是否是同一个类
         * @param dataSet 数据集
         * @return 返回结果
         */
        public boolean TargetAttrIsAllSame(DataSet dataSet)
        {
            String tempValue = null;
            for(ArrayList<String> row : dataSet.dataRows)
            {
                String value = row.get(originalDataSet.attrSet.size());
                if(tempValue == null)
                {
                    tempValue = value;
                    continue;
                }
                if(!tempValue.equals(value))
                {
                    return false;
                }
            }
            return true;
        }
        /**
         * 根据
         * @param dataSet 数据集
         * @return 返回
         */
        public Map<String, Integer> GetEachTargetValue(DataSet dataSet)
        {
            Map<String, Integer> map = new HashMap<String, Integer>();
            for(int i = 0; i < dataSet.dataRows.size(); i++)
            {
                String name = dataSet.dataRows.get(i).get(dataSet.attrSet.size());
                if(map.containsKey(name))
                {
                    map.put(name, map.get(name)+1);
                }
                else
                {
                    map.put(name, 1);
                }
            }
            return map;
        }
        /**
         * 找出目标属性值数最多的属性值
         * @param dataSet 数据集
         * @return 
         */
        public String GetMajorTargetValue(DataSet dataSet)
        {
            String maxTargetValue = null;
            int maxCount = -1;
            Map<String, Integer> map = this.GetEachTargetValue(dataSet);
            for(Entry<String, Integer> entry : map.entrySet())
            {
                if(entry.getValue() > maxCount)
                    maxTargetValue = entry.getKey();
            }
            return maxTargetValue;
        }

    }

Main类:

    package ID3;

    import java.util.ArrayList;
    import java.util.Map;
    import java.util.Map.Entry;

    public class Main {

        public static void main(String[] args) {
            // TODO Auto-generated method stub
            ArrayList<String> l1 = new ArrayList<String>();
            l1.add("m");
            l1.add("red");
            l1.add("s");
            ArrayList<String> l2 = new ArrayList<String>();
            l2.add("m");
            l2.add("blue");
            l2.add("m");
            ArrayList<String> l3 = new ArrayList<String>();
            l3.add("f");
            l3.add("blue");
            l3.add("m");
            ArrayList<String> l4 = new ArrayList<String>();
            l4.add("f");
            l4.add("yellow");
            l4.add("b");
            ArrayList<String> l5 = new ArrayList<String>();
            l5.add("m");
            l5.add("blue");
            l5.add("s");
            ArrayList<ArrayList<String>> l = new ArrayList<ArrayList<String>>();
            l.add(l1);
            l.add(l2);
            l.add(l3);
            l.add(l4);
            l.add(l5);
            ArrayList<String> attrSet = new ArrayList<String>();
            attrSet.add("sex");
            attrSet.add("color");
            DataSet dataSet = new DataSet(attrSet, "size");
            dataSet.AddRow(l1);
            dataSet.AddRow(l2);
            dataSet.AddRow(l3);
            dataSet.AddRow(l4);
            dataSet.AddRow(l5);

            ID3 ID = new ID3(dataSet);//生成决策树

            Node node = ID.GetDecisionTreeDFS(dataSet);//返回决策树根节点
            node.PrintTree(node, 0, null);
            String[] datas  = {"m", "blue"};//测试数据
            System.out.println("reult: " + node.Test(datas));
            DataSet dataset = ID.FindSpecificDT(dataSet, "color", "blue");

        }

    }

这里写图片描述
这里写图片描述

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值