本系统实现了决策树生成,只要输入合适的数据集,系统就可以生成一棵决策树。
数据集的输入使用二维数组,输入的个数为:序号+特征+分类结果。同时要把特征名以及对应的特征值传给程序,如此一来系统就可以建决策树。
关于决策树的定义这里不再列出,CSDN上有很多类似的博客。这些博客实现的Java代码很长,又没有注释,我看不懂,所以自己实现了一遍。我这里不再多加赘述。使用Java实现决策树个人觉得是不太明智的做法,比较繁琐,建议使用python实现。以下是代码,大部分应该是有注释的,后面可能是调到心累所有有些地方没有,留个纪念。原理还是很好懂的。
package homework;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
public class DecisionTree {
LinkedList<String[]> dataList = new LinkedList<String[]>();//保存训练集的数据,结构为:序号+属性+类别
LinkedList<String> attribute = new LinkedList<String>();//存放属性个数
DecisionTree father;
String attriValue;
String attriDivide;
LinkedList<DecisionTree> child;
HashMap<String,LinkedList<String>> attributeValue = new HashMap<String,LinkedList<String>>();//存放属性及其对应的属性值
public DecisionTree(String[][] data,HashMap<String,LinkedList<String>> attributeValue,DecisionTree father,String attriDivide,String attriValue) {
this.father = father;
this.attriDivide = attriDivide;
this.attriValue = attriValue;
this.attributeValue = attributeValue;
getDataAndAttribute(data,attributeValue);
if(!detectEnd()) {//当前节点不为终叶节点,可以继续往下分
String attriRoot = bestAttri();//得到当前的划分属性
Map<String,LinkedList<String[]>> child = divideByAttribute(attriRoot);//得到划分后的所有东西
this.child = new LinkedList<DecisionTree>();//后面划分的节点属于当前的儿子
//获得不同键值下面的数据集,先获取键值集
Set<String> keySet = child.keySet();
//遍历键值集
Iterator<String> keys = keySet.iterator();
while(keys.hasNext()) {
String key = keys.next();
LinkedList<String[]> childData = child.get(key);//获取此键值下面的所有数据集
HashMap<String,LinkedList<String>> newAttribute = this.attributeValue;
newAttribute.remove(attriRoot);
if(childData.size()==0)continue;
String[][] datas = new String[childData.size()][childData.get(0).length];//将child下面的data改为二维数组的形式
for(int i=0;i<childData.size();i++) {
datas[i] = childData.get(i);
}
DecisionTree childNode = new DecisionTree(datas,newAttribute,this,attriRoot,key);
this.child.add(childNode);
}
}
}
public void getDataAndAttribute(String[][] data,HashMap<String,LinkedList<String>> attribute) {
for(int i=0;i<data.length;i++) {//将所有数据集压入类的数据集中
this.dataList.add(data[i]);
}
Set<String> keySet = attribute.keySet();
Iterator<String> it = keySet.iterator();//将map里面的键值写入到本地的attribute表中,作为属性表
while(it.hasNext()) {
String s = it.next();
this.attribute.add(s);
}
}
boolean detectEnd() {//判断当前节点是否为终叶节点
Set<String> detect = new HashSet<String>();
for(int i=0;i<dataList.size();i++) {
String[] temp = dataList.get(i);
detect.add(temp[temp.length-1]);
}//当所有分类结果最终只有一种结果,就是终叶节点
if(detect.size()==1)return true;
else return false;
}
double calEntropy(String attribute) {
double result = 0;//所有属性值的熵值和
double totalNum = this.dataList.size();//总数据集的个数
Map<String,LinkedList<String[]>> divide = divideByAttribute(attribute);//得到按属性attribute值分类的结果
Set<String> keySet = divide.keySet();//得到所有键值
Iterator<String> iterator = keySet.iterator();//遍历所有键值
while(iterator.hasNext()) {
String key = iterator.next();
LinkedList<String[]> values = divide.get(key);//获得当前键值下所有的数据集
int count = values.size();//当前键值下的数据个数
Set<String> resultSet = new HashSet<String>();//使用Set来判断结果中有多少种
for(int i=0;i<count;i++) {
String[] temp = values.get(i);
resultSet.add(temp[temp.length-1]);
}
Iterator<String> iteratorResult = resultSet.iterator();//遍历结果种数
double resultInAttribute = 0;//当前属性值下的熵值
int countI;
while(iteratorResult.hasNext()) {
countI=0;//计算不同结果各自有多少种
String resultI = (String)iteratorResult.next();//当前的结果
for(int i=0;i<count;i++) {
String[] temp = values.get(i);//与数据集中的结果比较
if(temp[temp.length-1].equals(resultI))countI++;//如果数据与当前结果相同,计数加一
}
//计算得到当前属性值的熵
resultInAttribute = resultInAttribute - ((double)countI/count)*(Math.log((double)countI/count)/Math.log(2));
}
result = result + ((double)count/totalNum)*resultInAttribute;
}
return result;
}
public String bestAttri() {
double min = 100;
String choose = "";
for(int i=0;i<this.attribute.size();i++) {
double cal = calEntropy(this.attribute.get(i));
if(min>cal) {
min = cal;
choose = this.attribute.get(i);
}
}
return choose;
}
Map<String,LinkedList<String[]>> divideByAttribute(String attribute){
LinkedList<String> attriValue = this.attributeValue.get(attribute);//获得当前属性下的属性值
String[] content = this.dataList.get(0);//从本类的数据中拿出第0个,为了判断当前的attribute在哪一列
int col=0;
for(int i=1;i<content.length-1;i++) {
if(attriValue.contains(content[i])) {
col = i;//找到当前attribute所在的列
break;
}
}
Map<String,LinkedList<String[]>> result = new HashMap<String,LinkedList<String[]>>();//结果集
//下面开始按attribute的值对dataList分类
for(int i=0;i<attriValue.size();i++) {//遍历
LinkedList<String[]> resultValue = new LinkedList<String[]>();//当前attribute[i]的值
for(int j=0;j<this.dataList.size();j++) {
String[] temp = this.dataList.get(j);
if(temp[col].equals(attriValue.get(i)))resultValue.add(temp);
}
if(resultValue.size()!=0)result.put(attriValue.get(i), resultValue);
}
return result;
}
}