决策树

决策树代码,Java 实现(按平方误差选择划分点的回归树,用预测值的符号做二分类)

import java.util.ArrayList;
import java.util.Random;

/**
 * One node of a binary regression tree.
 *
 * <p>A node holds the training rows that reached it ({@code X}) together with
 * their integer targets ({@code y}). {@link #split()} searches for the
 * (feature, threshold) pair that minimises the summed squared error of the two
 * resulting partitions and, if one is found, creates the two children.
 * A node without children acts as a leaf and predicts the mean of its targets.
 */
class Node {
    /** Rows are not split further once a node holds this many rows or fewer. */
    private static final int MIN_ROWS_TO_SPLIT = 5;
    /** Nodes whose target spread is at or below this value are treated as pure. */
    private static final double MIN_SPREAD_TO_SPLIT = 0.0001;

    /** Column of {@code X} used by this node's split; meaningful only once children exist. */
    int split_feature_index;
    /** Threshold of the split: rows with feature value {@code <=} it go left, the rest go right. */
    double split_feature_value;
    Node left_child = null;
    Node right_child = null;
    /** Feature rows that reached this node. */
    ArrayList<ArrayList<Double>> X;
    /** Targets, parallel to {@code X}. */
    ArrayList<Integer> y;

    /**
     * Creates a node over the given rows and targets. The lists are stored as-is (not copied).
     *
     * @param X feature rows
     * @param y targets, one per row of {@code X}
     */
    public Node(ArrayList<ArrayList<Double>> X,
                ArrayList<Integer> y) {
        this.X = X;
        this.y = y;
    }

    /**
     * Creates a node that only records a split rule. {@code X} and {@code y}
     * stay {@code null}, so {@link #split()} must not be called on such a node.
     *
     * @param index feature column of the split
     * @param value threshold of the split
     */
    public Node (int index, int value) {
        split_feature_index = index;
        split_feature_value = value;
    }

    /**
     * Returns the sum of squared deviations of {@code x} from its mean
     * (i.e. the variance multiplied by the number of elements).
     * Yields NaN for an empty list.
     */
    private double get_Variance(ArrayList<Integer> x) {
        double sum = 0;
        double sumOfSquares = 0;
        for (int v : x) {
            sum += v;
            // Square in double so large targets cannot overflow int arithmetic.
            sumOfSquares += (double) v * v;
        }
        double avg = sum / x.size();
        return sumOfSquares - avg * avg * x.size();
    }

    /**
     * Returns the summed squared error obtained by sending every row whose
     * value in column {@code featureIndex} is {@code <= threshold} to the left
     * partition and all other rows to the right one, each partition being
     * predicted by its own mean.
     */
    private double splitLoss(int featureIndex, double threshold) {
        double leftSum = 0, rightSum = 0;
        int leftCount = 0;
        for (int j = 0; j < X.size(); j++) {
            if (X.get(j).get(featureIndex) <= threshold) {
                leftSum += y.get(j);
                leftCount++;
            } else {
                rightSum += y.get(j);
            }
        }
        int rightCount = X.size() - leftCount;
        double leftMean = leftSum / leftCount;
        // An empty right partition contributes no error; avoid 0/0.
        double rightMean = rightCount != 0 ? rightSum / rightCount : 0;

        double leftSe = 0, rightSe = 0;
        for (int j = 0; j < X.size(); j++) {
            if (X.get(j).get(featureIndex) <= threshold) {
                double residual = y.get(j) - leftMean;
                leftSe += residual * residual;
            } else {
                double residual = y.get(j) - rightMean;
                rightSe += residual * residual;
            }
        }
        return leftSe + rightSe;
    }

    /**
     * Tries to split this node into two children.
     *
     * <p>Every observed value of every feature is tried as a threshold; the
     * candidate with the smallest summed squared error wins, provided it is
     * strictly better than leaving the node unsplit. The node stays a leaf when
     * it has too few rows, its targets are (almost) constant, no candidate
     * improves the loss, or the winning candidate leaves one side empty.
     */
    public void split() {
        double unsplitLoss = get_Variance(y);
        if (X.size() <= MIN_ROWS_TO_SPLIT || unsplitLoss <= MIN_SPREAD_TO_SPLIT) {
            return;
        }

        double min_loss = unsplitLoss;
        boolean improved = false;
        int bestIndex = 0;
        double bestValue = 0;

        // Enumerate every candidate split point.
        int featureCount = X.get(0).size();
        for (int index = 0; index < featureCount; index++) {
            for (int i = 0; i < X.size(); i++) {
                double threshold = X.get(i).get(index);
                double loss = splitLoss(index, threshold);
                if (loss < min_loss) {
                    min_loss = loss;
                    bestIndex = index;
                    bestValue = threshold;
                    improved = true;
                }
            }
        }

        // Without this guard the node would be partitioned on the untouched
        // default rule (feature 0, threshold 0.0), which was never evaluated.
        if (!improved) {
            return;
        }
        split_feature_index = bestIndex;
        split_feature_value = bestValue;

        var lx = new ArrayList<ArrayList<Double>>();
        var ly = new ArrayList<Integer>();
        var rx = new ArrayList<ArrayList<Double>>();
        var ry = new ArrayList<Integer>();
        for (int j = 0; j < X.size(); j++) {
            if (X.get(j).get(split_feature_index) <= split_feature_value) {
                lx.add(X.get(j));
                ly.add(y.get(j));
            } else {
                rx.add(X.get(j));
                ry.add(y.get(j));
            }
        }

        // A one-sided partition can still win through floating-point rounding;
        // it carries no information, so keep the node as a leaf.
        if (lx.isEmpty() || rx.isEmpty()) {
            return;
        }

        right_child = new Node(rx, ry);
        left_child = new Node(lx, ly);
    }

    /**
     * Predicts the target for one feature row.
     *
     * @param x feature row, indexed like the rows of {@code X}
     * @return the mean target of the leaf that {@code x} falls into
     */
    public double predict(ArrayList<Double> x) {
        if (left_child == null) {
            // Leaf: return the mean of the targets stored here.
            double y_avg = 0;
            for (var i : y) { y_avg += i;}
            return y_avg / y.size();
        } else {
            return x.get(split_feature_index) <= split_feature_value ?
                    left_child.predict(x) : right_child.predict(x);
        }
    }
}

/**
 * Thin wrapper that grows a tree of {@link Node}s recursively.
 *
 * <p>Each {@code Tree} owns one {@code Node} as its root and, after
 * {@link #growth()}, one sub-{@code Tree} per child that the root produced.
 */
class Tree {
    Node root;
    /** Sub-tree wrapping the root's left child; {@code null} until grown or if absent. */
    Tree lf;
    /** Sub-tree wrapping the root's right child; {@code null} until grown or if absent. */
    Tree rt;

    /** Builds a tree whose root holds the given rows and targets. */
    Tree(ArrayList<ArrayList<Double>> X, ArrayList<Integer> y) {
        this(new Node(X, y));
    }

    /** Wraps an already constructed node as the root of a tree. */
    Tree(Node tr) {
        this.root = tr;
    }

    /**
     * Splits the root and then grows each child that the split created,
     * right side first, then left.
     */
    public void growth() {
        this.root.split();

        if (this.root.right_child != null) {
            Tree rightSubtree = new Tree(this.root.right_child);
            this.rt = rightSubtree;
            rightSubtree.growth();
        }

        if (this.root.left_child != null) {
            Tree leftSubtree = new Tree(this.root.left_child);
            this.lf = leftSubtree;
            leftSubtree.growth();
        }
    }

    /** Delegates the prediction for {@code x} to the root node. */
    public double predict(ArrayList<Double> x) {
        double prediction = this.root.predict(x);
        return prediction;
    }

}

/**
 * Smoke test for the tree: learns whether a random point of the unit square
 * lies inside the unit circle and prints training and test accuracy.
 */
public class Main {

    /**
     * Appends {@code count} labelled samples to {@code X} and {@code y}.
     * Each sample is a point (x1, x2) drawn from {@code r}; its label is 1
     * when x1^2 + x2^2 &lt; 1 and -1 otherwise.
     */
    private static void addSamples(Random r, int count,
                                   ArrayList<ArrayList<Double>> X,
                                   ArrayList<Integer> y) {
        for (int i = 0; i < count; i++) {
            double x1 = r.nextDouble();
            double x2 = r.nextDouble();
            var pt = new ArrayList<Double>();
            pt.add(x1);
            pt.add(x2);
            X.add(pt);
            y.add(x1 * x1 + x2 * x2 < 1 ? 1 : -1);
        }
    }

    /**
     * Returns the percentage (0-100) of rows whose label matches the sign of
     * the tree's prediction; a negative prediction maps to -1, anything else to 1.
     */
    private static double accuracyPercent(Tree tree,
                                          ArrayList<ArrayList<Double>> X,
                                          ArrayList<Integer> y) {
        int correct = 0;
        for (int i = 0; i < X.size(); i++) {
            int predict_val = tree.predict(X.get(i)) < 0 ? -1 : 1;
            if (predict_val == y.get(i)) {
                correct++;
            }
        }
        // Divide in floating point: an integer "100 / size" truncates to 0
        // as soon as there are more than 100 rows.
        return 100.0 * correct / X.size();
    }

    public static void main(String[] args) {
        // Check that the decision tree is built correctly.
        var X = new ArrayList<ArrayList<Double>>();
        var y = new ArrayList<Integer>();
        var testX = new ArrayList<ArrayList<Double>>();
        var testy = new ArrayList<Integer>();

        // Fixed seed keeps the run reproducible; training data is drawn first,
        // then the test data from the same generator.
        Random r = new Random(1);
        addSamples(r, 1000, X, y);
        addSamples(r, 100, testX, testy);

        var DS = new Tree(X, y);
        DS.growth();

        double fit_score = accuracyPercent(DS, X, y);
        double test_score = accuracyPercent(DS, testX, testy);

        System.out.println("fit_score:" + fit_score);
        System.out.println("test_score:" + test_score);
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值