数据挖掘 -- CART决策树算法

1. 算法原理

CART算法: 生成二叉决策树, 能够同时处理离散属性和连续属性。每次分裂时计算所有候选属性的Gini增益(父节点的Gini不纯度减去左右子节点按样本数加权后的Gini不纯度), 选择增益最大的属性及其划分方式进行分裂, 递归生成决策树。(离散属性用二进制位枚举把取值集合分成两部分; 连续属性先按取值排序(O(n log n)), 再顺序扫描相邻取值的中点作为候选阈值, 扫描时增量更新左右两侧的类别计数, 因此扫描部分的复杂度为O(n)。)

2. 代码实现

Node.java

package com.clxk1997.model;

/**
 * @Description Payload of a decision-tree node (决策树节点).
 *
 * <p>{@code field} is the attribute name this node's branch tests, or one of the markers
 * {@code "root"} (tree root) / {@code "leaf"} (terminal node). {@code value} is the branch
 * condition, e.g. {@code "<=3.5"} / {@code ">3.5"} for continuous attributes or a
 * {@code "|a||b|"}-delimited value set for discrete ones; for a leaf it is the predicted
 * class label.
 *
 * @Author Clxk
 * @Date 2019/5/5 15:53
 * @Version 1.0
 */
public class Node {

    private String field;
    private String value;

    /** Creates an empty node; fields must be set through the setters. */
    public Node() {
    }

    /**
     * Creates a node.
     *
     * @param field attribute name, or the marker "root" / "leaf"
     * @param value branch condition, or the class label for a leaf
     */
    public Node(String field, String value) {
        this.field = field;
        this.value = value;
    }

    public String getField() {
        return field;
    }

    public void setField(String field) {
        this.field = field;
    }

    public String getValue() {
        return value;
    }

    public void setValue(String value) {
        this.value = value;
    }
}

Field.java

package com.clxk1997.model;

/**
 * @Description Attribute (column) definition: a name plus its type code (属性标签).
 *
 * <p>{@code fieldClass} holds one of the type codes declared in
 * {@code com.clxk1997.utils.Fields}: ID (1), CONTINOUS (2), DISCRETE (3) or CLASS (4).
 *
 * @Author Clxk
 * @Date 2019/5/5 14:45
 * @Version 1.0
 */
public class Field {

    private String fieldName;
    private int fieldClass;

    /** Creates an empty attribute definition; fields must be set through the setters. */
    public Field() {
    }

    /**
     * Creates an attribute definition.
     *
     * @param fieldName  attribute name as it appears in the schema file
     * @param fieldClass attribute type code (see {@code Fields})
     */
    public Field(String fieldName, int fieldClass) {
        this.fieldName = fieldName;
        this.fieldClass = fieldClass;
    }

    public String getFieldName() {
        return fieldName;
    }

    public void setFieldName(String fieldName) {
        this.fieldName = fieldName;
    }

    public int getFieldClass() {
        return fieldClass;
    }

    public void setFieldClass(int fieldClass) {
        this.fieldClass = fieldClass;
    }
}

DecisionTree.java

package com.clxk1997.model;

import com.clxk1997.utils.Fields;

import java.util.ArrayList;
import java.util.List;

/**
 * @Description 决策树
 * @Author Clxk
 * @Date 2019/5/5 15:59
 * @Version 1.0
 */
public class DecisionTree {

    private Node node;
    private DecisionTree lTree, rTree;

    constructor...
   	getters ans setters....

    public static DecisionTree initTree() {
        DecisionTree root = new DecisionTree();
        root.setNode(new Node("root","root"));
        root.setlTree(null);
        root.setrTree(null);
        return root;
    }

    public static void print(DecisionTree root) {
        if(root == null) return;
        System.out.println("属性: " + root.node.getField() + "  " + "判断: " + root.getNode().getValue());
        print(root.lTree);
        print(root.rTree);
    }

    public static String search(DecisionTree tree, ArrayList<String> list, List<Field> fields_copy) {
        if(tree.getNode().getField().equals("leaf")) return tree.getNode().getValue();
        if(tree.getlTree().getNode().getField().equals("leaf")) return tree.getlTree().getNode().getValue();
        DecisionTree lTree = tree.getlTree();
        DecisionTree rTree = tree.getrTree();
        String lfield = lTree.getNode().getField();
        String lvalue = lTree.getNode().getValue();
        int idx = 0;
        boolean isContinous = true;
        for(int i = 0; i < fields_copy.size(); i++) {
           if(fields_copy.get(i).getFieldName().equals(lfield)) {
               if(fields_copy.get(i).getFieldClass() == Fields.CONTINOUS) {
                   isContinous = true;
               } else {
                   isContinous = false;
               }
               idx = i;
               break;
           }
        }
        if(isContinous) {
            double value = Double.valueOf(lvalue.substring(2));
            if(Double.valueOf(list.get(idx)) <= value) {
                return search(lTree, list, fields_copy);
            } else {
                return search(rTree, list, fields_copy);
            }
        } else {
            String value = "|" + list.get(idx) + "|";
            if(lvalue.indexOf(value) > -1) {
                return search(lTree, list, fields_copy);
            } else {
                return search(rTree, list, fields_copy);
            }
        }
    }
}

Fields.java

package com.clxk1997.utils;

/**
 * @Description Attribute type codes (属性标签) stored in {@code Field#fieldClass}.
 *
 * <p>The spelling {@code CONTINOUS} is kept because other code refers to it by name.
 *
 * @Author Clxk
 * @Date 2019/5/5 14:47
 * @Version 1.0
 */
public final class Fields {

    /** Identifier column; skipped when searching for splits. */
    public static final int ID = 1;
    /** Numeric attribute, split by a threshold ({@code <= t} / {@code > t}). */
    public static final int CONTINOUS = 2;
    /** Categorical attribute, split into two value subsets. */
    public static final int DISCRETE = 3;
    /** Class-label column; skipped when searching for splits. */
    public static final int CLASS = 4;

    private Fields() {
        // Constants holder; not meant to be instantiated.
    }
}

Cart.java

package com.clxk1997.cart;

import com.clxk1997.model.DecisionTree;
import com.clxk1997.model.Field;
import com.clxk1997.model.Node;
import com.clxk1997.utils.Fields;

import java.io.*;
import java.lang.reflect.Array;
import java.util.*;

/**
 * @Description CART decision-tree training and prediction (Cart算法实现).
 *
 * <p>{@link #main} reads the attribute schema and the training rows, grows a binary tree by
 * repeatedly choosing the split with the largest Gini gain, reads the test rows, and appends
 * one predicted class label per test row to the output file.
 *
 * <p>Row format, as consumed below: whitespace-separated columns. Column 0 and the last
 * column are never considered for splitting; the last column is the class label. Training
 * and test files end at a line reading {@code exit} (or at end of file).
 *
 * <p>NOTE(review): child impurity is computed one-vs-rest against a single reference class
 * (the class of the first row of the partition), while the parent impurity in
 * {@link #getClassGini} is multi-class, so gains are exact only for two-class data.
 *
 * @Author Clxk
 * @Date 2019/5/5 14:42
 * @Version 1.0
 */
public class Cart {

    // Hard-coded input/output locations.
    private static final String FIELDS_PATH = "D:\\CART\\src\\input_fields.txt";
    private static final String TRAINS_PATH = "D:\\CART\\src\\input_trains.txt";
    private static final String TESTS_PATH = "D:\\CART\\src\\input_test.txt";
    private static final String OUTPUT_PATH = "D:\\CART\\src\\output_ans.txt";

    // Discrete splits are enumerated as int bitmasks over the distinct values, so the number
    // of distinct values must stay below 31 (2^(k-1) - 1 candidate splits are scored).
    private static final int MAX_DISCRETE_VALUES = 30;

    // Attribute schema: `fields` is handed to training, `fields_copy` is used for prediction.
    private static List<Field> fields = new ArrayList<>();
    private static List<Field> fields_copy = new ArrayList<>();
    // Training set
    private static List<ArrayList<String>> trains = new ArrayList<>();
    // Test set
    private static List<ArrayList<String>> tests = new ArrayList<>();
    // Decision tree being built
    private static DecisionTree tree = DecisionTree.initTree();


    public static void main(String[] args) {

        try {
            inputModel();

            training();

            inputTest();

            OutputResult();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads the attribute schema, then the training set.
     */
    public static void inputModel() throws FileNotFoundException {

        inputFields();

        inputTrains();
    }

    /**
     * Reads the training set.
     */
    public static void inputTrains() throws FileNotFoundException {
        System.out.println("请输入训练集:(\"exit\"结束)");
        readRows(TRAINS_PATH, trains);
    }

    /**
     * Reads the attribute schema: a count line, then one "name typeCode" line per attribute.
     */
    public static void inputFields() throws FileNotFoundException {
        System.out.println("请输入属性数量:");
        try (Scanner scan = new Scanner(new FileInputStream(FIELDS_PATH))) {
            int fieldCnt = scan.nextInt();
            scan.nextLine(); // consume the rest of the count line

            for (int i = 0; i < fieldCnt; i++) {
                System.out.println("请输入第" + (i+1) + "个属性及属性类型(ID-1 | CONTINOUS-2 | DISCRETE-3 | CLASS-4)");
                String[] fieldStr = scan.nextLine().trim().split("\\s+");
                Field field = new Field(fieldStr[0], Integer.parseInt(fieldStr[1]));
                fields.add(field);
                fields_copy.add(field);
            }
        }
    }

    /**
     * Reads whitespace-separated rows from {@code path} into {@code target} until a line
     * equal to "exit" or end of file. Blank lines are skipped; runs of whitespace count as a
     * single separator so columns stay aligned with the schema.
     */
    private static void readRows(String path, List<ArrayList<String>> target) throws FileNotFoundException {
        try (Scanner scan = new Scanner(new FileInputStream(path))) {
            while (scan.hasNextLine()) {
                String line = scan.nextLine().trim();
                if (line.equals("exit")) break;
                if (line.isEmpty()) continue;
                target.add(new ArrayList<>(Arrays.asList(line.split("\\s+"))));
            }
        }
    }

    /**
     * Trains the global tree on the global training set.
     */
    public static void training() {

        train(fields, trains, tree);

        System.out.println("===================训练完成====================");
    }

    /**
     * Recursively grows {@code tree} from {@code trains}.
     *
     * <p>Stops with two identical "leaf" children when the partition is pure or when all
     * non-ID attribute columns agree (majority class). Otherwise splits on the attribute
     * with the largest Gini gain, removes that column, and recurses; an empty side of a
     * split becomes a leaf carrying the parent's majority class.
     *
     * <p>Note: rows in {@code trains} are modified in place (re-sorted by continuous
     * attributes, and the split column is removed from every row).
     *
     * @param fields schema aligned with the columns of {@code trains}; not modified
     * @param trains rows of the current partition
     * @param tree   node whose children are to be created
     */
    public static void train(List<Field> fields, List<ArrayList<String>> trains, DecisionTree tree) {
        if (trains.isEmpty()) return;
        if (isPure(trains)) {
            setLeaves(tree, classOf(trains.get(0)));
            return;
        }
        String conflictAns = isConflict(trains);
        if (!conflictAns.isEmpty()) {
            setLeaves(tree, conflictAns);
            return;
        }

        double classGini = getClassGini(trains);
        System.out.println("类属性的Gini值为: " + classGini);
        int maxIndex = -1;
        double maxGini = Double.NEGATIVE_INFINITY;
        double ansResult = 0;
        boolean isContinous = true;
        String[] keys = null;
        for (int i = 1; i < fields.size() - 1; i++) {
            int fieldClass = fields.get(i).getFieldClass();
            if (fieldClass == Fields.ID || fieldClass == Fields.CLASS) continue;
            boolean continous = fieldClass == Fields.CONTINOUS;
            Object[] result = continous
                    ? getContinousGini(trains, i, classGini)
                    : getDiscreteGini(trains, i, classGini);
            double gini = (double) result[0];
            System.out.println("Gini值计算:     " + "属性: " + fields.get(i).getFieldName() + "  Gini值: " + gini);
            if (gini > maxGini) {
                maxGini = gini;
                ansResult = (double) result[1];
                isContinous = continous;
                if (!continous) keys = (String[]) result[2];
                maxIndex = i;
            }
        }
        // No eligible attribute left to split on: predict the majority class.
        if (maxIndex < 0) {
            setLeaves(tree, majorityClass(trains));
            return;
        }

        String fieldName = fields.get(maxIndex).getFieldName();
        String lcond;
        String rcond;
        if (isContinous) {
            System.out.println("最大Gini值:     " + "属性: " + fieldName + "  分隔条件: " + "<= " + String.valueOf(ansResult) + " and " + "> " + String.valueOf(ansResult));
            lcond = "<=" + String.valueOf(ansResult);
            rcond = ">" + String.valueOf(ansResult);
        } else {
            // Bit i of the winning mask decides which side keys[i] goes to.
            StringBuilder l = new StringBuilder();
            StringBuilder r = new StringBuilder();
            for (int i = 0; i < keys.length; i++) {
                StringBuilder side = (((int) ansResult & (1 << i)) != 0) ? r : l;
                side.append('|').append(keys[i]).append('|');
            }
            lcond = l.toString();
            rcond = r.toString();
            System.out.println("最大Gini值:     " + "属性: " + fieldName + "  分隔条件: " + lcond + " and " + rcond);
        }

        DecisionTree lTree = new DecisionTree(new Node(fieldName, lcond), null, null);
        DecisionTree rTree = new DecisionTree(new Node(fieldName, rcond), null, null);
        tree.setlTree(lTree);
        tree.setrTree(rTree);

        String majority = majorityClass(trains);
        List<ArrayList<String>> llist = new ArrayList<>();
        List<ArrayList<String>> rlist = new ArrayList<>();
        for (ArrayList<String> row : trains) {
            String v = row.get(maxIndex);
            boolean goLeft = isContinous
                    ? Double.parseDouble(v) <= ansResult
                    : lcond.contains("|" + v + "|");
            row.remove(maxIndex);
            (goLeft ? llist : rlist).add(row);
        }

        // train() never mutates its schema argument, so both children can share one copy.
        List<Field> childFields = new ArrayList<>(fields);
        childFields.remove(maxIndex);
        growChild(childFields, llist, lTree, majority);
        growChild(childFields, rlist, rTree, majority);
    }

    /** Trains {@code child} on {@code rows}, or makes it a {@code fallback} leaf if empty. */
    private static void growChild(List<Field> fields, List<ArrayList<String>> rows,
                                  DecisionTree child, String fallback) {
        if (rows.isEmpty()) {
            setLeaves(child, fallback);
        } else {
            train(fields, rows, child);
        }
    }

    /** Attaches two identical "leaf" children predicting {@code label}. */
    private static void setLeaves(DecisionTree tree, String label) {
        tree.setlTree(new DecisionTree(new Node("leaf", label), null, null));
        tree.setrTree(new DecisionTree(new Node("leaf", label), null, null));
    }

    /** Class label of a row (its last column). */
    private static String classOf(List<String> row) {
        return row.get(row.size() - 1);
    }

    /** Most frequent class label; ties go to the label seen first. */
    private static String majorityClass(List<ArrayList<String>> rows) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (ArrayList<String> row : rows) {
            counts.merge(classOf(row), 1, Integer::sum);
        }
        String best = "";
        int bestCount = 0;
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    /**
     * Conflicting node: all attribute columns (excluding column 0 and the class column) are
     * identical across rows, yet the class labels may differ.
     *
     * @param trains non-empty partition
     * @return the majority class if the node is conflicting, otherwise ""
     */
    public static String isConflict(List<ArrayList<String>> trains) {
        int len = trains.get(0).size();
        for (int i = 1; i < len - 1; i++) {
            String first = trains.get(0).get(i);
            for (ArrayList<String> row : trains) {
                if (!row.get(i).equals(first)) return "";
            }
        }
        return majorityClass(trains);
    }

    /**
     * Returns whether every row carries the same class label (true for an empty list).
     */
    public static boolean isPure(List<ArrayList<String>> trains) {
        if (trains.isEmpty()) return true;
        String first = classOf(trains.get(0));
        for (ArrayList<String> row : trains) {
            if (!classOf(row).equals(first)) return false;
        }
        return true;
    }

    /**
     * Best threshold split for a continuous attribute.
     *
     * <p>Sorts {@code trains} in place by column {@code idx}, then scans once, moving rows
     * from the right side to the left and updating counts incrementally. Thresholds are
     * midpoints between distinct adjacent values; a midpoint between equal values would not
     * separate them under "&lt;= threshold", so it is not scored.
     *
     * @return {@code [maxGain (Double), threshold (Double)]}; {@code [0.0, 0.0]} if no
     *         threshold has positive gain
     */
    public static Object[] getContinousGini(List<ArrayList<String>> trains, int idx, double classGini) {
        trains.sort((a, b) -> Double.compare(Double.parseDouble(a.get(idx)), Double.parseDouble(b.get(idx))));

        int n = trains.size();
        double[] values = new double[n];
        for (int i = 0; i < n; i++) {
            values[i] = Double.parseDouble(trains.get(i).get(idx));
        }

        // Reference class for the one-vs-rest impurity: class of the first row after sorting.
        String target = classOf(trains.get(0));
        int rightTarget = 0;
        for (ArrayList<String> row : trains) {
            if (classOf(row).equals(target)) rightTarget++;
        }
        int leftTarget = 0;
        double maxgini = 0;
        double ansValue = 0;
        for (int i = 0; i < n - 1; i++) {
            if (classOf(trains.get(i)).equals(target)) {
                leftTarget++;
                rightTarget--;
            }
            if (values[i] == values[i + 1]) continue;
            int leftSize = i + 1;
            double cur = calculDeta(n, classGini, leftTarget, leftSize, rightTarget, n - leftSize);
            if (cur > maxgini) {
                maxgini = cur;
                ansValue = (values[i] + values[i + 1]) / 2.0;
            }
        }
        Object[] ans = new Object[2];
        ans[0] = maxgini;
        ans[1] = ansValue;
        return ans;
    }

    /**
     * Best two-way partition of a discrete attribute's values.
     *
     * <p>Each partition is an int bitmask over the distinct values ({@code keys}). A mask and
     * its complement describe the same partition, so only masks with the highest bit clear
     * are scored.
     *
     * @return {@code [maxGain (Double), mask (Double), keys (String[])]}
     * @throws IllegalArgumentException if the attribute has more than
     *         {@value #MAX_DISCRETE_VALUES} distinct values
     */
    public static Object[] getDiscreteGini(List<ArrayList<String>> trains, int idx, double classGini) {
        // Reference class for the one-vs-rest impurity.
        String classValue = classOf(trains.get(0));
        Map<String, Integer> fieldmp = new HashMap<>();   // value -> number of rows
        Map<String, Integer> targetmp = new HashMap<>();  // value -> rows whose class is classValue
        for (ArrayList<String> row : trains) {
            String key = row.get(idx);
            fieldmp.merge(key, 1, Integer::sum);
            if (classOf(row).equals(classValue)) {
                targetmp.merge(key, 1, Integer::sum);
            }
        }
        String[] keys = fieldmp.keySet().toArray(new String[0]);
        int k = keys.length;
        if (k > MAX_DISCRETE_VALUES) {
            throw new IllegalArgumentException(
                    "Too many distinct values (" + k + ") in discrete attribute column " + idx);
        }

        int n = trains.size();
        double maxgini = 0;
        double ansValue = 0;
        for (int mask = 1; mask < (1 << (k - 1)); mask++) {
            int lcnt0 = 0, lcnt1 = 0, rcnt0 = 0, rcnt1 = 0;
            for (int j = 0; j < k; j++) {
                int total = fieldmp.get(keys[j]);
                int hits = targetmp.getOrDefault(keys[j], 0);
                if ((mask & (1 << j)) != 0) {
                    lcnt0 += hits;
                    lcnt1 += total;
                } else {
                    rcnt0 += hits;
                    rcnt1 += total;
                }
            }
            double cur = calculDeta(n, classGini, lcnt0, lcnt1, rcnt0, rcnt1);
            if (cur > maxgini) {
                maxgini = cur;
                ansValue = mask;
            }
        }
        Object[] ans = new Object[3];
        ans[0] = maxgini;
        ans[1] = ansValue;
        ans[2] = keys;
        return ans;
    }

    /**
     * Gini gain of a binary split: parent impurity minus the size-weighted impurities of the
     * two sides.
     *
     * @param size      rows in the parent
     * @param classGini parent impurity
     * @param lcnt0     left rows of the reference class
     * @param lcnt1     left rows in total
     * @param rcnt0     right rows of the reference class
     * @param rcnt1     right rows in total
     */
    public static double calculDeta(int size, double classGini, int lcnt0, int lcnt1, int rcnt0, int rcnt1) {
        return classGini - ((double)lcnt1 / (double)size) * calculGini(lcnt0, lcnt1) - ((double)rcnt1 / (double)size) * calculGini(rcnt0, rcnt1);
    }

    /**
     * Two-class Gini impurity of one side of a split.
     *
     * @param lcnt0 rows of the reference class
     * @param lcnt1 rows in total; an empty side (0) has impurity 0 instead of NaN
     */
    public static double calculGini(int lcnt0, int lcnt1) {
        if (lcnt1 == 0) return 0d;
        double a = (double)lcnt0 / (double)lcnt1;
        double b = (double)(lcnt1-lcnt0) / (double)lcnt1;
        return 1d - a * a - b * b;
    }

    /** Multi-class Gini impurity of the class column: 1 - sum(p_c^2). */
    public static double getClassGini(List<ArrayList<String>> trains) {
        Map<String, Integer> mp = new HashMap<>();
        for (ArrayList<String> row : trains) {
            mp.merge(classOf(row), 1, Integer::sum);
        }
        double ans = 1;
        double total = trains.size();
        for (int count : mp.values()) {
            double p = count / total;
            ans -= p * p;
        }
        return ans;
    }

    /**
     * Reads the test set.
     */
    public static void inputTest() throws FileNotFoundException {

        System.out.println("请输入测试集:(\"exit\"结束)");
        readRows(TESTS_PATH, tests);
        System.out.println("===================开始预测====================");

    }

    /**
     * Prints the tree and appends one predicted label per test row to the output file.
     */
    public static void OutputResult() throws Exception {
        DecisionTree.print(tree);
        try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(OUTPUT_PATH, true))) {
            for (ArrayList<String> list : tests) {
                bufferedWriter.write(DecisionTree.search(tree, list, fields_copy) + "\n");
            }
        }
        System.out.println("===================预测完成====================");
    }

}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Chook_lxk

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值