// Decision tree (CART-style regression tree), implemented in Java.
import java.util.ArrayList;
import java.util.Random;
/**
 * One node of a binary regression tree.
 *
 * Internal nodes route a sample to {@code left_child} when
 * {@code x[split_feature_index] <= split_feature_value}, otherwise to
 * {@code right_child}. Leaves (children == null) predict the mean of their
 * training targets.
 */
class Node {
    // Feature column this node splits on (meaningful only for internal nodes).
    int split_feature_index;
    // Threshold: samples with feature value <= threshold go left.
    double split_feature_value;
    Node left_child = null;
    Node right_child = null;
    // Training rows and their targets that reached this node.
    ArrayList<ArrayList<Double>> X;
    ArrayList<Integer> y;

    public Node(ArrayList<ArrayList<Double>> X,
                ArrayList<Integer> y) {
        this.X = X;
        this.y = y;
    }

    // Legacy constructor kept for source compatibility. NOTE(review): it leaves
    // X and y null, so split()/predict() must not be called on such a node;
    // nothing in this file uses it.
    public Node(int index, int value) {
        split_feature_index = index;
        split_feature_value = value;
    }

    /**
     * Total squared error of {@code x} around its mean: sum((x_i - mean)^2).
     * Despite the historical name, this is n * variance, not the variance
     * itself; it is only ever compared against other total-SSE values, so the
     * constant factor cancels out.
     */
    private double get_Variance(ArrayList<Integer> x) {
        double sum = 0;
        double sumSq = 0;
        for (var v : x) {
            sum += v;
            sumSq += (double) v * v;
        }
        // sum((x - mean)^2) == sum(x^2) - (sum x)^2 / n
        return sumSq - sum * sum / x.size();
    }

    /**
     * Greedily finds the (feature, threshold) pair that minimizes the summed
     * squared error of the two resulting halves, and creates the children.
     * Stops (leaf) when the node is small (&le; 5 samples), already pure, or
     * when no candidate split strictly reduces the error.
     */
    public void split() {
        // Cheap size guard first; also protects get_Variance from an empty y.
        if (X.size() <= 5) {
            return;
        }
        double y_sse = get_Variance(y);
        if (y_sse <= 0.0001) {
            return; // node is (nearly) pure
        }

        // Totals over all targets, so each half's SSE comes from one pass:
        // rightSum = totalSum - leftSum, etc.
        double totalSum = 0, totalSumSq = 0;
        for (var v : y) {
            totalSum += v;
            totalSumSq += (double) v * v;
        }

        double min_loss = y_sse;
        int best_index = -1;
        double best_value = 0;
        int n = X.size();
        // Candidate thresholds: every observed value of every feature.
        for (int index = 0; index < X.get(0).size(); index++) {
            for (int i = 0; i < n; i++) {
                double threshold = X.get(i).get(index);
                double leftSum = 0, leftSumSq = 0;
                int leftCnt = 0;
                for (int j = 0; j < n; j++) {
                    if (X.get(j).get(index) <= threshold) {
                        double v = y.get(j);
                        leftSum += v;
                        leftSumSq += v * v;
                        leftCnt++;
                    }
                }
                int rightCnt = n - leftCnt;
                // leftCnt >= 1 always: sample i itself satisfies <= threshold.
                double leftSse = leftSumSq - leftSum * leftSum / leftCnt;
                double rightSum = totalSum - leftSum;
                double rightSse = rightCnt != 0
                        ? (totalSumSq - leftSumSq) - rightSum * rightSum / rightCnt
                        : 0;
                double loss = leftSse + rightSse;
                if (loss < min_loss) {
                    min_loss = loss;
                    best_index = index;
                    best_value = threshold;
                }
            }
        }

        // Bug fix: previously, when no candidate strictly reduced the SSE, the
        // node still partitioned on the default index 0 / value 0.0, producing
        // a bogus split for data with non-positive feature values.
        if (best_index < 0) {
            return;
        }
        split_feature_index = best_index;
        split_feature_value = best_value;

        var lx = new ArrayList<ArrayList<Double>>();
        var ly = new ArrayList<Integer>();
        var rx = new ArrayList<ArrayList<Double>>();
        var ry = new ArrayList<Integer>();
        for (int j = 0; j < n; j++) {
            if (X.get(j).get(split_feature_index) <= split_feature_value) {
                lx.add(X.get(j));
                ly.add(y.get(j));
            } else {
                rx.add(X.get(j));
                ry.add(y.get(j));
            }
        }
        // Defensive: refuse a degenerate split (one side empty).
        if (lx.isEmpty() || rx.isEmpty()) {
            return;
        }
        left_child = new Node(lx, ly);
        right_child = new Node(rx, ry);
    }

    /**
     * Recursively routes {@code x} to a leaf and returns that leaf's mean
     * training target.
     */
    public double predict(ArrayList<Double> x) {
        if (left_child == null) {
            double sum = 0;
            for (var v : y) {
                sum += v;
            }
            return sum / y.size();
        }
        return x.get(split_feature_index) <= split_feature_value
                ? left_child.predict(x)
                : right_child.predict(x);
    }
}
/**
 * Thin recursive wrapper around {@link Node}: each Tree owns one node and
 * mirrors that node's children as subtrees once {@link #growth()} runs.
 */
class Tree {
    Node root;
    Tree lf; // subtree wrapping root.left_child (null until growth)
    Tree rt; // subtree wrapping root.right_child (null until growth)

    Tree(ArrayList<ArrayList<Double>> X, ArrayList<Integer> y) {
        root = new Node(X, y);
    }

    Tree(Node tr) {
        root = tr;
    }

    /** Splits this node, then recursively grows whichever children appeared. */
    public void growth() {
        root.split();
        Node l = root.left_child;
        Node r = root.right_child;
        if (l != null) {
            lf = new Tree(l);
            lf.growth();
        }
        if (r != null) {
            rt = new Tree(r);
            rt.growth();
        }
    }

    /** Delegates prediction to the underlying node. */
    public double predict(ArrayList<Double> x) {
        return root.predict(x);
    }
}
/**
 * Sanity check for the tree: learn the indicator of the unit quarter-disc
 * (x1^2 + x2^2 < 1 on [0,1)^2) and print train/test accuracy percentages.
 */
public class Main {
    /** Draws one random point in [0,1)^2 and labels it +1 inside the quarter-disc, -1 outside. */
    private static void addSample(Random r,
                                  ArrayList<ArrayList<Double>> xs,
                                  ArrayList<Integer> ys) {
        double x1 = r.nextDouble();
        double x2 = r.nextDouble();
        var pt = new ArrayList<Double>();
        pt.add(x1);
        pt.add(x2);
        xs.add(pt);
        ys.add(x1 * x1 + x2 * x2 < 1 ? 1 : -1);
    }

    /** Percentage (0..100) of points whose predicted sign matches the label. */
    private static double accuracy(Tree tree,
                                   ArrayList<ArrayList<Double>> xs,
                                   ArrayList<Integer> ys) {
        int hits = 0;
        for (int i = 0; i < xs.size(); i++) {
            // Threshold the regression output at 0 to recover a +/-1 label.
            int predicted = tree.predict(xs.get(i)) < 0 ? -1 : 1;
            if (predicted == ys.get(i)) {
                hits++;
            }
        }
        // Bug fix: the original test-score loop used integer division
        // (100 / testX.size()), which was only accidentally correct because
        // the test set happened to have exactly 100 points.
        return 100.0 * hits / xs.size();
    }

    public static void main(String[] args) {
        var X = new ArrayList<ArrayList<Double>>();
        var y = new ArrayList<Integer>();
        var testX = new ArrayList<ArrayList<Double>>();
        var testy = new ArrayList<Integer>();
        Random r = new Random(1); // fixed seed for reproducible scores
        for (int i = 0; i < 1000; i++) {
            addSample(r, X, y);
        }
        for (int i = 0; i < 100; i++) {
            addSample(r, testX, testy);
        }
        var DS = new Tree(X, y);
        DS.growth();
        System.out.println("fit_score:" + accuracy(DS, X, y));
        System.out.println("test_score:" + accuracy(DS, testX, testy));
    }
}