1. 算法原理
CART算法: 生成二叉决策树, 能够同时处理离散属性和连续属性。对每个候选属性计算划分后的Gini增益(父节点的Gini值减去左右子节点Gini值的加权和), 选择Gini增益最大的属性及划分点进行分裂, 递归生成决策树。(离散属性用二进制枚举把取值集合分成两部分; 连续属性先按取值排序, 再在一次遍历中增量更新左右两侧的类别计数, 使排序之后枚举划分点的复杂度降到O(n))
2. 代码实现
Node.java
package com.clxk1997.model;
/**
 * Payload of one decision-tree node: the attribute it refers to and the
 * condition (or, for a leaf, the predicted class) attached to it.
 *
 * <p>Elsewhere in this project the attribute name {@code "root"} marks the
 * tree root and {@code "leaf"} marks a leaf whose value is the class label.
 *
 * @Author Clxk
 * @Date 2019/5/5 15:53
 * @Version 1.0
 */
public class Node {
    /** Attribute name, or the marker "root" / "leaf". */
    private String field;
    /** Split condition for an inner node, class label for a leaf. */
    private String value;

    /** Creates an empty node; populate it through the setters. */
    public Node() {
    }

    /**
     * @param field attribute name (or "root" / "leaf")
     * @param value split condition or class label
     */
    public Node(String field, String value) {
        this.field = field;
        this.value = value;
    }

    /** @return the attribute name (or "root" / "leaf") */
    public String getField() {
        return field;
    }

    /** @param field the attribute name (or "root" / "leaf") */
    public void setField(String field) {
        this.field = field;
    }

    /** @return the split condition or class label */
    public String getValue() {
        return value;
    }

    /** @param value the split condition or class label */
    public void setValue(String value) {
        this.value = value;
    }
}
Field.java
package com.clxk1997.model;
/**
 * Describes one column of the data set: its name and its kind.
 *
 * <p>The kind is one of the integer codes declared in
 * {@code com.clxk1997.utils.Fields} (ID, CONTINOUS, DISCRETE, CLASS).
 *
 * @Author Clxk
 * @Date 2019/5/5 14:45
 * @Version 1.0
 */
public class Field {
    /** Column name. */
    private String fieldName;
    /** Column kind code. */
    private int fieldClass;

    /** Creates an empty descriptor; populate it through the setters. */
    public Field() {
    }

    /**
     * @param fieldName  column name
     * @param fieldClass column kind code
     */
    public Field(String fieldName, int fieldClass) {
        this.fieldName = fieldName;
        this.fieldClass = fieldClass;
    }

    /** @return the column name */
    public String getFieldName() {
        return fieldName;
    }

    /** @param fieldName the column name */
    public void setFieldName(String fieldName) {
        this.fieldName = fieldName;
    }

    /** @return the column kind code */
    public int getFieldClass() {
        return fieldClass;
    }

    /** @param fieldClass the column kind code */
    public void setFieldClass(int fieldClass) {
        this.fieldClass = fieldClass;
    }
}
DecisionTree.java
package com.clxk1997.model;
import com.clxk1997.utils.Fields;
import java.util.ArrayList;
import java.util.List;
/**
* @Description 决策树
* @Author Clxk
* @Date 2019/5/5 15:59
* @Version 1.0
*/
public class DecisionTree {
private Node node;
private DecisionTree lTree, rTree;
constructor...
getters ans setters....
public static DecisionTree initTree() {
DecisionTree root = new DecisionTree();
root.setNode(new Node("root","root"));
root.setlTree(null);
root.setrTree(null);
return root;
}
public static void print(DecisionTree root) {
if(root == null) return;
System.out.println("属性: " + root.node.getField() + " " + "判断: " + root.getNode().getValue());
print(root.lTree);
print(root.rTree);
}
public static String search(DecisionTree tree, ArrayList<String> list, List<Field> fields_copy) {
if(tree.getNode().getField().equals("leaf")) return tree.getNode().getValue();
if(tree.getlTree().getNode().getField().equals("leaf")) return tree.getlTree().getNode().getValue();
DecisionTree lTree = tree.getlTree();
DecisionTree rTree = tree.getrTree();
String lfield = lTree.getNode().getField();
String lvalue = lTree.getNode().getValue();
int idx = 0;
boolean isContinous = true;
for(int i = 0; i < fields_copy.size(); i++) {
if(fields_copy.get(i).getFieldName().equals(lfield)) {
if(fields_copy.get(i).getFieldClass() == Fields.CONTINOUS) {
isContinous = true;
} else {
isContinous = false;
}
idx = i;
break;
}
}
if(isContinous) {
double value = Double.valueOf(lvalue.substring(2));
if(Double.valueOf(list.get(idx)) <= value) {
return search(lTree, list, fields_copy);
} else {
return search(rTree, list, fields_copy);
}
} else {
String value = "|" + list.get(idx) + "|";
if(lvalue.indexOf(value) > -1) {
return search(lTree, list, fields_copy);
} else {
return search(rTree, list, fields_copy);
}
}
}
}
Fields.java
package com.clxk1997.utils;
/**
 * Integer codes for the kind of a data-set column.
 *
 * @Author Clxk
 * @Date 2019/5/5 14:47
 * @Version 1.0
 */
public final class Fields {
    /** Row identifier column. */
    public static final int ID = 1;
    /** Continuous (numeric) attribute. */
    public static final int CONTINOUS = 2;
    /** Discrete (categorical) attribute. */
    public static final int DISCRETE = 3;
    /** Class label column. */
    public static final int CLASS = 4;

    /** Constants holder; not meant to be instantiated. */
    private Fields() {
    }
}
Cart.java
package com.clxk1997.cart;
import com.clxk1997.model.DecisionTree;
import com.clxk1997.model.Field;
import com.clxk1997.model.Node;
import com.clxk1997.utils.Fields;
import java.io.*;
import java.lang.reflect.Array;
import java.util.*;
/**
* @Description Cart算法实现
* @Author Clxk
* @Date 2019/5/5 14:42
* @Version 1.0
*/
public class Cart {
// Column descriptors read by inputFields(); this list is handed to train().
private static List<Field> fields = new ArrayList<>();
// Second list filled with the same descriptors by inputFields(); passed to DecisionTree.search() when predicting.
private static List<Field> fields_copy = new ArrayList<>();
// Training set: one list of string cells per row.
private static List<ArrayList<String>> trains = new ArrayList<>();
// Test set: one list of string cells per row.
private static List<ArrayList<String>> tests = new ArrayList<>();
// Decision tree root, created empty and grown by training().
private static DecisionTree tree = DecisionTree.initTree();
/**
 * Program entry point: reads the model input, trains the tree, reads the
 * test set and writes the predictions. Any failure is reported by printing
 * its stack trace.
 *
 * @param args unused
 */
public static void main(String[] args) {
    try {
        // Stage 1: attribute descriptors and training rows.
        inputModel();
        // Stage 2: grow the decision tree.
        training();
        // Stage 3: read the samples to classify.
        inputTest();
        // Stage 4: classify and write the results.
        OutputResult();
    } catch (Exception failure) {
        failure.printStackTrace();
    }
}
/**
 * Loads everything needed for training: first the attribute descriptors,
 * then the training rows.
 *
 * @throws FileNotFoundException if one of the input files does not exist
 */
public static void inputModel() throws FileNotFoundException {
    // Descriptors must be read before the rows.
    inputFields();
    inputTrains();
}
/**
 * Reads the training set from {@code D:\CART\src\input_trains.txt}.
 *
 * <p>Each line is split on single whitespace characters into one row and
 * appended to {@code trains}. Reading stops at a line equal to {@code exit}
 * or at end of file, whichever comes first.
 *
 * @throws FileNotFoundException if the training file does not exist
 */
public static void inputTrains() throws FileNotFoundException {
    System.out.println("请输入训练集:(\"exit\"结束)");
    // try-with-resources closes the scanner (and the underlying stream) on every path.
    try (Scanner scan = new Scanner(new FileInputStream("D:\\CART\\src\\input_trains.txt"))) {
        // hasNextLine() guards against a file that lacks the terminating "exit" line.
        while (scan.hasNextLine()) {
            String strLine = scan.nextLine();
            if (strLine.equals("exit")) {
                break;
            }
            String[] split = strLine.split("\\s");
            trains.add(new ArrayList<>(Arrays.asList(split)));
        }
    }
}
/**
 * Reads the attribute descriptors from {@code D:\CART\src\input_fields.txt}.
 *
 * <p>The first token of the file is the number of attributes. Each following
 * line holds an attribute name and its kind code, separated by whitespace.
 * Every descriptor is appended to both {@code fields} and {@code fields_copy}.
 *
 * @throws FileNotFoundException if the descriptor file does not exist
 */
public static void inputFields() throws FileNotFoundException {
    System.out.println("请输入属性数量:");
    // try-with-resources closes the scanner (and the underlying stream) on every path.
    try (Scanner scan = new Scanner(new FileInputStream("D:\\CART\\src\\input_fields.txt"))) {
        int fieldCnt = scan.nextInt();
        // Consume the rest of the count line before reading descriptor lines.
        scan.nextLine();
        for (int i = 0; i < fieldCnt; i++) {
            System.out.println("请输入第" + (i+1) + "个属性及属性类型(ID-1 | CONTINOUS-2 | DISCRETE-3 | CLASS-4)");
            String strLine = scan.nextLine();
            String[] fieldStr = strLine.split("\\s");
            int fieldClass = Integer.parseInt(fieldStr[1]);
            Field field = new Field(fieldStr[0], fieldClass);
            fields.add(field);
            fields_copy.add(field);
        }
    }
}
/**
 * Grows the decision tree from the loaded descriptors and training rows,
 * then prints a completion banner.
 */
public static void training() {
    final String banner = "===================训练完成====================";
    train(fields, trains, tree);
    System.out.println(banner);
}
/**
 * Recursively grows the subtree below {@code tree} from the given rows.
 *
 * <p>Steps:
 * <ol>
 *   <li>No rows: the node is left without children.</li>
 *   <li>All rows share one class, or all attribute columns are identical:
 *       two leaf children carrying the (majority) class are attached.</li>
 *   <li>Otherwise every candidate column (indices 1 .. size-2, skipping ID
 *       and CLASS kinds) is scored and the one with the largest gain is used
 *       to split. The chosen column is removed and both halves recurse.</li>
 * </ol>
 *
 * <p>The row lists passed in are not modified by the split: each child
 * receives copies of its rows with the chosen column removed. Scoring a
 * continuous column does reorder {@code trains} (see getContinousGini).
 *
 * @param fields descriptors of the columns present in {@code trains}
 * @param trains rows to learn from; the last cell of each row is the class
 * @param tree   node under which the children are attached
 */
public static void train(List<Field> fields, List<ArrayList<String>> trains, DecisionTree tree) {
    if (trains.isEmpty()) {
        return;
    }
    if (isPure(trains)) {
        ArrayList<String> firstRow = trains.get(0);
        attachLeaves(tree, firstRow.get(firstRow.size() - 1));
        return;
    }
    String conflictAns = isConflict(trains);
    // Compare by content; the previous reference comparison only worked through string interning.
    if (!conflictAns.isEmpty()) {
        attachLeaves(tree, conflictAns);
        return;
    }

    double classGini = getClassGini(trains);
    System.out.println("类属性的Gini值为: " + classGini);

    int maxIndex = 0;
    double maxGini = -100000;
    double ansResult = 0;
    boolean isContinous = true;
    String[] keys = null;
    for (int i = 1; i < fields.size() - 1; i++) {
        int fieldClass = fields.get(i).getFieldClass();
        if (fieldClass == Fields.ID || fieldClass == Fields.CLASS) {
            continue;
        }
        boolean continuous = fieldClass == Fields.CONTINOUS;
        Object[] result = continuous
                ? getContinousGini(trains, i, classGini)
                : getDiscreteGini(trains, i, classGini);
        double gain = (Double) result[0];
        System.out.println("Gini值计算: " + "属性: " + fields.get(i).getFieldName() + " Gini值: " + gain);
        if (gain > maxGini) {
            maxGini = gain;
            ansResult = (Double) result[1];
            isContinous = continuous;
            maxIndex = i;
            if (!continuous) {
                keys = (String[]) result[2];
            }
        }
    }

    // With only the ID and class columns left there is nothing to split on.
    if (fields.size() <= 2) {
        return;
    }

    String fieldName = fields.get(maxIndex).getFieldName();
    String lvalue;
    String rvalue;
    if (isContinous) {
        System.out.println("最大Gini值: " + "属性: " + fieldName + " 分隔条件: " + "<= " + String.valueOf(ansResult) + " and " + "> " + String.valueOf(ansResult));
        lvalue = "<=" + String.valueOf(ansResult);
        rvalue = ">" + String.valueOf(ansResult);
    } else {
        // ansResult is a bit mask over keys: set bits go right, clear bits go left.
        StringBuilder left = new StringBuilder();
        StringBuilder right = new StringBuilder();
        for (int i = 0; i < keys.length; i++) {
            StringBuilder side = (((int) ansResult & (1 << i)) != 0) ? right : left;
            side.append("|").append(keys[i]).append("|");
        }
        lvalue = left.toString();
        rvalue = right.toString();
        System.out.println("最大Gini值: " + "属性: " + fieldName + " 分隔条件: " + lvalue + " and " + rvalue);
    }

    DecisionTree lTree = new DecisionTree(new Node(fieldName, lvalue), null, null);
    DecisionTree rTree = new DecisionTree(new Node(fieldName, rvalue), null, null);
    tree.setlTree(lTree);
    tree.setrTree(rTree);

    // Each child gets its own descriptor list without the column just used.
    ArrayList<Field> leftFields = new ArrayList<>(fields);
    leftFields.remove(maxIndex);
    ArrayList<Field> rightFields = new ArrayList<>(leftFields);

    ArrayList<ArrayList<String>> llist = new ArrayList<>();
    ArrayList<ArrayList<String>> rlist = new ArrayList<>();
    for (ArrayList<String> row : trains) {
        String cell = row.get(maxIndex);
        boolean goLeft = isContinous
                ? Double.parseDouble(cell) <= ansResult
                : lvalue.contains("|" + cell + "|");
        // Copy before dropping the column so the caller's rows stay intact.
        ArrayList<String> reduced = new ArrayList<>(row);
        reduced.remove(maxIndex);
        (goLeft ? llist : rlist).add(reduced);
    }
    train(leftFields, llist, lTree);
    train(rightFields, rlist, rTree);
}

/** Attaches two leaf children carrying {@code label} to {@code tree}. */
private static void attachLeaves(DecisionTree tree, String label) {
    tree.setlTree(new DecisionTree(new Node("leaf", label), null, null));
    tree.setrTree(new DecisionTree(new Node("leaf", label), null, null));
}
/**
 * Detects a "conflict" node: all attribute columns are identical across the
 * rows, so no further split can separate the classes.
 *
 * <p>Attribute columns are the indices 1 .. size-2 of a row (the first cell
 * is skipped, the last cell is the class). When they all agree, the most
 * frequent class label is returned; ties go to the label that appears first
 * in {@code trains}.
 *
 * @param trains non-empty list of rows
 * @return the majority class label for a conflict node, or the empty string
 *         when at least one attribute column still varies
 */
public static String isConflict(List<ArrayList<String>> trains) {
    ArrayList<String> first = trains.get(0);
    int len = first.size();
    for (int col = 1; col < len - 1; col++) {
        String expected = first.get(col);
        for (ArrayList<String> row : trains) {
            if (!Objects.equals(expected, row.get(col))) {
                return "";
            }
        }
    }

    // Insertion-ordered so that ties resolve to the earliest label seen.
    Map<String, Integer> votes = new LinkedHashMap<>();
    for (ArrayList<String> row : trains) {
        votes.merge(row.get(row.size() - 1), 1, Integer::sum);
    }
    String winner = "";
    int winnerVotes = 0;
    for (Map.Entry<String, Integer> entry : votes.entrySet()) {
        if (entry.getValue() > winnerVotes) {
            winnerVotes = entry.getValue();
            winner = entry.getKey();
        }
    }
    return winner;
}
/**
 * Tells whether every row carries the same class label (the last cell).
 *
 * @param trains rows to inspect
 * @return {@code true} when at most one distinct class label occurs,
 *         including the case of an empty list
 */
public static boolean isPure(List<ArrayList<String>> trains) {
    boolean seenAny = false;
    String reference = null;
    for (ArrayList<String> row : trains) {
        String label = row.get(row.size() - 1);
        if (!seenAny) {
            // The first label becomes the one all others must match.
            reference = label;
            seenAny = true;
        } else if (!Objects.equals(reference, label)) {
            return false;
        }
    }
    return true;
}
/**
 * Finds the best binary threshold for a continuous column.
 *
 * <p>The rows are sorted in place by the column value, then every midpoint
 * between two neighbouring, distinct values is scored with calculDeta. The
 * left/right class counts are updated incrementally while scanning, so the
 * scan after sorting is a single pass. Neighbours with equal values are
 * skipped: a threshold equal to that value would send both rows left under
 * the {@code <=} rule, so the counts at that position do not describe a
 * partition the tree can actually make.
 *
 * <p>Counts track one class against the rest: the class of the first row
 * after sorting.
 *
 * @param trains    non-empty rows; reordered by this call
 * @param idx       index of the continuous column
 * @param classGini Gini value of the class column over {@code trains}
 * @return a two-element array: {@code [0]} the best gain as {@code Double}
 *         (0 when no threshold improves on the parent), {@code [1]} the
 *         matching threshold as {@code Double} (0 when none was chosen)
 */
public static Object[] getContinousGini(List<ArrayList<String>> trains, int idx, double classGini) {
    trains.sort(Comparator.comparingDouble(
            (ArrayList<String> row) -> Double.parseDouble(row.get(idx))));

    String curStr = trains.get(0).get(trains.get(0).size() - 1);
    // Start with every row on the right-hand side.
    int lcnt0 = 0, lcnt1 = 0, rcnt0 = 0, rcnt1 = trains.size();
    for (ArrayList<String> train : trains) {
        if (train.get(train.size() - 1).equals(curStr)) {
            rcnt0++;
        }
    }

    double maxgini = 0;
    double ansValue = 0;
    for (int i = 0; i < trains.size() - 1; i++) {
        ArrayList<String> current = trains.get(i);
        double lo = Double.parseDouble(current.get(idx));
        double hi = Double.parseDouble(trains.get(i + 1).get(idx));

        // Move row i from the right side to the left side.
        lcnt1++;
        rcnt1--;
        if (current.get(current.size() - 1).equals(curStr)) {
            rcnt0--;
            lcnt0++;
        }

        if (lo == hi) {
            // No threshold can separate two equal values.
            continue;
        }
        double curt = (lo + hi) / 2.0;
        double cur = calculDeta(trains.size(), classGini, lcnt0, lcnt1, rcnt0, rcnt1);
        if (cur > maxgini) {
            maxgini = cur;
            ansValue = curt;
        }
    }
    return new Object[] {maxgini, ansValue};
}
/**
 * Finds the best two-way grouping of the values of a discrete column.
 *
 * <p>The distinct values are numbered in the iteration order of a
 * {@code HashMap} key set, and every proper non-empty subset is tried as a
 * bit mask. For each mask the rows whose value has its bit set form one
 * group and the remaining rows the other; the grouping is scored with
 * calculDeta. Counts track one class against the rest: the class of the
 * first row.
 *
 * @param trains    non-empty rows
 * @param idx       index of the discrete column
 * @param classGini Gini value of the class column over {@code trains}
 * @return a three-element array: {@code [0]} the best gain as {@code Double}
 *         (0 when nothing improves on the parent), {@code [1]} the winning
 *         bit mask stored as {@code Double} (0 when none was chosen),
 *         {@code [2]} the {@code String[]} of distinct values the mask bits
 *         refer to
 */
public static Object[] getDiscreteGini(List<ArrayList<String>> trains, int idx, double classGini) {
    ArrayList<String> head = trains.get(0);
    String trackedClass = head.get(head.size() - 1);

    // Rows per value, and rows per value and class.
    Map<String, Integer> rowsPerValue = new HashMap<>();
    Map<String, Map<String, Integer>> classesPerValue = new HashMap<>();
    for (ArrayList<String> row : trains) {
        String value = row.get(idx);
        String label = row.get(row.size() - 1);
        rowsPerValue.merge(value, 1, Integer::sum);
        classesPerValue.computeIfAbsent(value, unused -> new HashMap<>())
                .merge(label, 1, Integer::sum);
    }

    String[] keys = classesPerValue.keySet().toArray(new String[0]);
    int distinct = keys.length;

    // Look the per-value counts up once instead of once per mask.
    int[] trackedPerKey = new int[distinct];
    int[] totalPerKey = new int[distinct];
    for (int k = 0; k < distinct; k++) {
        trackedPerKey[k] = classesPerValue.get(keys[k]).getOrDefault(trackedClass, 0);
        totalPerKey[k] = rowsPerValue.get(keys[k]);
    }

    double bestGain = 0;
    double bestMask = 0;
    // Masks 0 and all-ones are excluded; with one distinct value the range is empty.
    int allOnes = (1 << distinct) - 1;
    for (int mask = 1; mask < allOnes; mask++) {
        int inTracked = 0, inTotal = 0, outTracked = 0, outTotal = 0;
        for (int k = 0; k < distinct; k++) {
            if ((mask & (1 << k)) != 0) {
                inTracked += trackedPerKey[k];
                inTotal += totalPerKey[k];
            } else {
                outTracked += trackedPerKey[k];
                outTotal += totalPerKey[k];
            }
        }
        double gain = calculDeta(trains.size(), classGini, inTracked, inTotal, outTracked, outTotal);
        if (gain > bestGain) {
            bestGain = gain;
            bestMask = mask;
        }
    }

    Object[] ans = new Object[3];
    ans[0] = bestGain;
    ans[1] = bestMask;
    ans[2] = keys;
    return ans;
}
/**
 * Computes the gain of a two-way split: the parent's Gini value minus the
 * size-weighted Gini values of the two sides.
 *
 * @param size      number of rows in the parent
 * @param classGini Gini value of the parent
 * @param lcnt0     rows of the tracked class on the first side
 * @param lcnt1     all rows on the first side
 * @param rcnt0     rows of the tracked class on the second side
 * @param rcnt1     all rows on the second side
 * @return the gain of the split
 */
public static double calculDeta(int size, double classGini, int lcnt0, int lcnt1, int rcnt0, int rcnt1) {
    double total = (double) size;
    double firstWeight = (double) lcnt1 / total;
    double secondWeight = (double) rcnt1 / total;
    return classGini
            - firstWeight * calculGini(lcnt0, lcnt1)
            - secondWeight * calculGini(rcnt0, rcnt1);
}
/**
 * Computes the two-class Gini value of one side of a split:
 * {@code 1 - p^2 - q^2}, where {@code p} is the share of the tracked class
 * and {@code q} the share of all other rows.
 *
 * @param lcnt0 rows of the tracked class
 * @param lcnt1 all rows on this side
 * @return the Gini value of this side
 */
public static double calculGini(int lcnt0, int lcnt1) {
    double total = (double) lcnt1;
    double trackedShare = (double) lcnt0 / total;
    double otherShare = (double) (lcnt1 - lcnt0) / total;
    return 1d - trackedShare * trackedShare - otherShare * otherShare;
}
/**
 * Computes the Gini value of the class column: one minus the sum of the
 * squared class shares. The class is the last cell of each row.
 *
 * @param trains rows to measure
 * @return the Gini value; 1 for an empty list
 */
public static double getClassGini(List<ArrayList<String>> trains) {
    Map<String, Integer> perClass = new HashMap<>();
    for (ArrayList<String> row : trains) {
        perClass.merge(row.get(row.size() - 1), 1, Integer::sum);
    }

    double gini = 1;
    double total = (double) trains.size();
    for (int count : perClass.values()) {
        double share = (double) count / total;
        gini -= share * share;
    }
    return gini;
}
/**
 * Reads the test set from {@code D:\CART\src\input_test.txt}.
 *
 * <p>Each line is split on single whitespace characters into one row and
 * appended to {@code tests}. Reading stops at a line equal to {@code exit}
 * or at end of file, whichever comes first. A banner announcing the start of
 * prediction is printed afterwards.
 *
 * @throws FileNotFoundException if the test file does not exist
 */
public static void inputTest() throws FileNotFoundException {
    System.out.println("请输入测试集:(\"exit\"结束)");
    // try-with-resources closes the scanner (and the underlying stream) on every path.
    try (Scanner scan = new Scanner(new FileInputStream("D:\\CART\\src\\input_test.txt"))) {
        // hasNextLine() guards against a file that lacks the terminating "exit" line.
        while (scan.hasNextLine()) {
            String strLine = scan.nextLine();
            if (strLine.equals("exit")) {
                break;
            }
            String[] split = strLine.split("\\s");
            tests.add(new ArrayList<>(Arrays.asList(split)));
        }
    }
    System.out.println("===================开始预测====================");
}
/**
 * Prints the trained tree, classifies every test row and appends one
 * predicted label per line to {@code D:\CART\src\output_ans.txt}.
 *
 * @throws Exception if the output file cannot be written or a row cannot be
 *         classified
 */
public static void OutputResult() throws Exception {
    DecisionTree.print(tree);
    // try-with-resources flushes and closes the file even when a prediction or write fails.
    try (BufferedWriter bufferedWriter =
                 new BufferedWriter(new FileWriter("D:\\CART\\src\\output_ans.txt", true))) {
        for (ArrayList<String> list : tests) {
            String ans = DecisionTree.search(tree, list, fields_copy);
            bufferedWriter.write(ans + "\n");
        }
    }
    System.out.println("===================预测完成====================");
}
}