决策树(信息增益)的java实现

1、ResultNode.java:

package com.decision.my;


public class ResultNode {


    /** 属性名 */
    private String name;
    
    /** 属性值 */
    private String value;
    
    /** 类别 */
    private String classfy;


    public String getClassfy() {
        return classfy;
    }


    public void setClassfy(String classfy) {
        this.classfy = classfy;
    }


    public String getName() {
        return name;
    }


    public void setName(String name) {
        this.name = name;
    }


    public String getValue() {
        return value;
    }


    public void setValue(String value) {
        this.value = value;
    }
}


2、TreeNode.java

package com.decision.my;


import java.util.ArrayList;
import java.util.List;


/**
 * <p>本类描述: 树节点</p>
 * <p>其他说明: </p>
 * @author Wang Haiyang
 * @date 2015-6-23 下午07:22:21
 */
public class TreeNode {


    /** 树节点的名字 */
    private String name;
    
    /** 树节点的信息熵 */
    private Double value;
    
    /** 树叶节点的分类 */
    private String classify;
    
    /** 树节点的孩子 */
    private List<TreeNode> childs = new ArrayList<TreeNode>();
    
    /** 树节点的父亲 */
    private TreeNode parent;
    
    private List<String> path = new ArrayList<String>();


    public List<String> getPath() {
        return path;
    }


    public void setPath(List<String> path) {
        this.path = path;
    }


    public String getName() {
        return name;
    }


    public void setName(String name) {
        this.name = name;
    }


    public Double getValue() {
        return value;
    }


    public void setValue(Double value) {
        this.value = value;
    }


    public String getClassify() {
        return classify;
    }


    public void setClassify(String classify) {
        this.classify = classify;
    }


    public List<TreeNode> getChilds() {
        return childs;
    }


    public void setChilds(List<TreeNode> childs) {
        this.childs = childs;
    }


    public TreeNode getParent() {
        return parent;
    }


    public void setParent(TreeNode parent) {
        this.parent = parent;
    }
}


3、Decision.java

package com.decision.my;


import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;




/**
 * <p>本类描述: </p>
 * <p>其他说明: </p>
 * @author Wang Haiyang
 * @date 2015-6-23 下午07:21:17
 */
public class Decision {


    /**
     * 方法描述:创建决策树
     * @param D
     * @param attributeList
     * @return
     */
    public static TreeNode createDecisionTree(List<ArrayList<String>> D, List<String> attributeList) {
        TreeNode root = new TreeNode();
        String classfy = pure(D); // 判断D是否是纯的,若不是,返回null;若是,返回类别
        if (classfy != null) {
            root.setClassify(classfy);
            return root;
        } else if (attributeList == null || attributeList.size() == 0) {
            root.setClassify("more");
            return root;
        } else {
             // 得到当前信息增益最大的节点
            Double gainD = getGainD(D);
            Double max = 0D;
            Map<String, Integer> mapAttribute = new HashMap<String, Integer>();
            int ind = 0;
            for (int i = 0; i < attributeList.size(); i++) {
                Map<String, Integer> map = new HashMap<String, Integer>();
                for (ArrayList<String> lists : D) {
                    Integer value = map.get(lists.get(i));
                    map.put(lists.get(i), value == null ? 1 : ++value);
                }
                Double gainA = getGainA(map, D, gainD, i);
                if (gainA > max) {
                    max = gainA;
                    ind = i;
                    mapAttribute = map;
                }
            }
            root.setName(attributeList.get(ind));
            root.setValue(max);
            ArrayList<String> newAttributeList = (ArrayList<String>)((ArrayList<String>)attributeList).clone();
            newAttributeList.remove(ind);
            for (Entry<String, Integer> entry : mapAttribute.entrySet()) {
                List<ArrayList<String>> datas = getData(entry, D, ind);
                TreeNode child = createDecisionTree(datas, newAttributeList);
                child.setParent(root);
                child.getPath().add(entry.getKey());
                root.getChilds().add(child);
            }
        }
        return root;
    }


    /**
     * 方法描述:得到属性对应的数据集
     * @param entry
     * @param d
     * @param ind
     * @return
     */
    private static List<ArrayList<String>> getData(Entry<String, Integer> entry, List<ArrayList<String>> D, int ind) {
        List<ArrayList<String>> results = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> lists : D) {
            if (entry.getKey().equals(lists.get(ind))) {
                ArrayList<String> newLists = (ArrayList<String>)lists.clone();
                newLists.remove(ind);
                results.add(newLists);
            }
        }
        return results;
    }


    /**
     * 方法描述:得到每个属性的信息增益
     * @param map
     * @param D
     * @param gainD
     * @param i
     * @return
     */
    private static Double getGainA(Map<String, Integer> map, List<ArrayList<String>> D, Double gainD, int i) {
        Double result = 0D;
        Double info = 0D;
        for (Entry<String, Integer> entry : map.entrySet()) {
            double attributeCount = Double.parseDouble(String.valueOf(entry.getValue()));
            double allCount = Double.parseDouble(String.valueOf(D.size()));
            Map<String, Integer> mapAttribute1 = new HashMap<String, Integer>();
            for (ArrayList<String> lists : D) {
                if (entry.getKey().equals(lists.get(i))) {
                    Integer value = mapAttribute1.get(lists.get(lists.size() - 1));
                    mapAttribute1.put(lists.get(lists.size() - 1), value == null ? 1 : ++value);
                }
            }
            Double item = 0D;
            for (Entry<String, Integer> e : mapAttribute1.entrySet()) {
                Double p = Double.parseDouble(String.valueOf(e.getValue())) / attributeCount;
                item += -p * Math.log(p) / Math.log(2);
            }
            info += (attributeCount / allCount) * item;
        }
        result = gainD - info;
        return result;
    }


    /**
     * 方法描述:得到数据集D的信息增益
     * @param D
     * @return
     */
    private static Double getGainD(List<ArrayList<String>> D) {
        Double result = 0D;
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (ArrayList<String> lists : D) {
            Integer value = map.get(lists.get(lists.size() - 1));
            map.put(lists.get(lists.size() - 1), value == null ? 1 : ++value);
        }
        for (Entry<String, Integer> entry : map.entrySet()) {
            Double p = Double.parseDouble(String.valueOf(entry.getValue())) / Double.parseDouble(String.valueOf(D.size()));
            result += -p * Math.log(p) / Math.log(2);
        }
        return result;
    }


    /**
     * 方法描述:判断D是否是纯的,若不是,返回null;若是,返回类别
     * @param d
     * @return
     */
    private static String pure(List<ArrayList<String>> D) {
        Set<String> sets = new HashSet<String>();
        for (ArrayList<String> lists : D) {
            sets.add(lists.get(lists.size() - 1));
        }
        Iterator<String> iterator = sets.iterator();
        if (sets.size() == 1) {
            return iterator.next();
        } else {
            return null;
        }
    }
    
    public static void main(String[] args) {
        TreeNode node = createDecisionTree(configData(),configAttribute());
        List<ArrayList<ResultNode>> nodes = getAllPath(node);
        for (int i = 0; i < nodes.size(); i++) {
            System.out.println("第" + (i + 1) + "条路径:");
            ArrayList<ResultNode> lists = nodes.get(i);
            for (int j = 0; j < lists.size() - 1; j++) {
                System.out.print("(" + lists.get(j).getName() + ":" + lists.get(j).getValue() + "),");
            }
            System.out.print(lists.get(lists.size() - 1).getClassfy());
            System.out.println();
        }
    }
    
    /**
     * 方法描述:得到指定树的所有叶子节点
     * @param root
     * @return
     */
    private static List<TreeNode> getLeafs(TreeNode root) {
        List<TreeNode> results = new ArrayList<TreeNode>();
        traverseTree(root, results);
        return results;
    }


    /**
     * 方法描述:递归遍整个树
     * @param node
     */
    private static void traverseTree(TreeNode node, List<TreeNode> results) {
        List<TreeNode> childs = node.getChilds();
        if (childs == null || childs.size() == 0) {
            results.add(node);
        } else {
            for (TreeNode child : childs) {
                traverseTree(child, results);
            }
        }
    }
    
    /**
     * 方法描述:遍历树,产生所有路径
     * @param node
     * @return
     */
    private static List<ArrayList<ResultNode>> getAllPath(TreeNode root) {
        List<ArrayList<ResultNode>> results = new ArrayList<ArrayList<ResultNode>>();
        List<TreeNode> leafs = getLeafs(root);
        for (TreeNode node : leafs) {
            ArrayList<ResultNode> resultNodes = new ArrayList<ResultNode>();
            ResultNode resultNode = new ResultNode();
            resultNode.setClassfy(node.getClassify());
            resultNodes.add(resultNode);
            TreeNode parent = node.getParent();
            TreeNode p = node;
            while (parent != null) {
                ResultNode resultNode1 = new ResultNode();
                resultNode1.setName(parent.getName());
                resultNode1.setValue(p.getPath().get(0));
                resultNodes.add(resultNode1);
                p = parent;
                parent = parent.getParent();
            }
            Collections.reverse(resultNodes);
            results.add(resultNodes);
        }
        return results;
    }


    private static List<String> configAttribute() {
        List<String> results = new ArrayList<String>();
        results.add("age");
        results.add("income");
        results.add("student");
        results.add("credit_rating");
        return results;
    }


    private static List<ArrayList<String>> configData() {
        List<ArrayList<String>> results = new ArrayList<ArrayList<String>>();
        try {
            BufferedReader is = new BufferedReader(new InputStreamReader(new FileInputStream(new File("D:/data.txt"))));
            String line = is.readLine();
            while (line != null) {
                String[] split = line.split(",");
                ArrayList<String> s1s = new ArrayList<String>();
                for (int i = 0; i < split.length; i++) {
                    s1s.add(split[i]);
                }
                results.add(s1s);
                line = is.readLine();
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return results;
    }
}


4、测试

测试数据:d盘放一个data.txt,数据如下

youth,high,no,fair,no
youth,high,no,excellent,no
middle_aged,high,no,fair,yes
senior,medium,no,fair,yes
senior,low,yes,fair,yes
senior,low,yes,excellent,no
middle_aged,low,yes,excellent,yes
youth,medium,no,fair,no
youth,low,yes,fair,yes
senior,medium,yes,fair,yes
youth,medium,yes,excellent,yes
middle_aged,medium,no,excellent,yes
middle_aged,high,yes,fair,yes
senior,medium,no,excellent,no

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值