使用Java实现的决策树

6 篇文章 0 订阅
5 篇文章 0 订阅

      本系统实现了决策树生成,只要输入合适的数据集,系统就可以生成一棵决策树。

      数据集的输入使用二维数组,输入的个数为:序号+特征+分类结果。同时要把特征名以及对应的特征值传给程序,如此一来系统就可以建决策树。

      关于决策树的定义这里不再列出,CSDN上有很多类似的博客。这些博客实现的Java代码很长,又没有注释,我看不懂,所以自己实现了一遍。我这里不再多加赘述。使用Java实现决策树个人觉得是不太明智的做法,比较繁琐,建议使用python实现。以下是代码,大部分应该是有注释的,后面可能是调到心累所有有些地方没有,留个纪念。原理还是很好懂的。

package homework;

import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;


public class DecisionTree {
	LinkedList<String[]> dataList = new LinkedList<String[]>();//保存训练集的数据,结构为:序号+属性+类别
	LinkedList<String> attribute = new LinkedList<String>();//存放属性个数
	DecisionTree father;
	String attriValue;
	String attriDivide;
	LinkedList<DecisionTree> child;
	HashMap<String,LinkedList<String>> attributeValue = new HashMap<String,LinkedList<String>>();//存放属性及其对应的属性值
	public DecisionTree(String[][] data,HashMap<String,LinkedList<String>> attributeValue,DecisionTree father,String attriDivide,String attriValue) {
		this.father = father;
		this.attriDivide = attriDivide;
		this.attriValue = attriValue;
		this.attributeValue = attributeValue;
		getDataAndAttribute(data,attributeValue);
		if(!detectEnd()) {//当前节点不为终叶节点,可以继续往下分
			String attriRoot = bestAttri();//得到当前的划分属性
			Map<String,LinkedList<String[]>> child = divideByAttribute(attriRoot);//得到划分后的所有东西
			this.child = new LinkedList<DecisionTree>();//后面划分的节点属于当前的儿子
			//获得不同键值下面的数据集,先获取键值集
			Set<String> keySet = child.keySet();
			//遍历键值集
			Iterator<String> keys = keySet.iterator();
			while(keys.hasNext()) {
				String key = keys.next();
				LinkedList<String[]> childData = child.get(key);//获取此键值下面的所有数据集
				HashMap<String,LinkedList<String>> newAttribute = this.attributeValue;
				newAttribute.remove(attriRoot);
				if(childData.size()==0)continue;
				String[][] datas = new String[childData.size()][childData.get(0).length];//将child下面的data改为二维数组的形式
				for(int i=0;i<childData.size();i++) {
					datas[i] = childData.get(i);
				}
				DecisionTree childNode = new DecisionTree(datas,newAttribute,this,attriRoot,key);
				this.child.add(childNode);
			}
		}
	}
	public void getDataAndAttribute(String[][] data,HashMap<String,LinkedList<String>> attribute) {
		for(int i=0;i<data.length;i++) {//将所有数据集压入类的数据集中
			this.dataList.add(data[i]);
		}
		Set<String> keySet = attribute.keySet();
		Iterator<String> it = keySet.iterator();//将map里面的键值写入到本地的attribute表中,作为属性表
		while(it.hasNext()) {
			String s = it.next();
			this.attribute.add(s);
		}
	}
	boolean detectEnd() {//判断当前节点是否为终叶节点
		Set<String> detect = new HashSet<String>();
		for(int i=0;i<dataList.size();i++) {
			String[] temp = dataList.get(i);
			detect.add(temp[temp.length-1]);
		}//当所有分类结果最终只有一种结果,就是终叶节点
		if(detect.size()==1)return true;
		else return false;
	}
	double calEntropy(String attribute) {
		double result = 0;//所有属性值的熵值和
		double totalNum = this.dataList.size();//总数据集的个数
		Map<String,LinkedList<String[]>> divide = divideByAttribute(attribute);//得到按属性attribute值分类的结果
		Set<String> keySet = divide.keySet();//得到所有键值
		Iterator<String> iterator = keySet.iterator();//遍历所有键值
		while(iterator.hasNext()) {
			String key = iterator.next();
			LinkedList<String[]> values = divide.get(key);//获得当前键值下所有的数据集
			int count = values.size();//当前键值下的数据个数
			Set<String> resultSet = new HashSet<String>();//使用Set来判断结果中有多少种
			for(int i=0;i<count;i++) {
				String[] temp = values.get(i);
				resultSet.add(temp[temp.length-1]);
			}
			Iterator<String> iteratorResult = resultSet.iterator();//遍历结果种数
			double resultInAttribute = 0;//当前属性值下的熵值
			int countI;
			while(iteratorResult.hasNext()) {
				countI=0;//计算不同结果各自有多少种
				String resultI = (String)iteratorResult.next();//当前的结果
				for(int i=0;i<count;i++) {
					String[] temp = values.get(i);//与数据集中的结果比较
					if(temp[temp.length-1].equals(resultI))countI++;//如果数据与当前结果相同,计数加一
				}
				//计算得到当前属性值的熵
				resultInAttribute = resultInAttribute - ((double)countI/count)*(Math.log((double)countI/count)/Math.log(2));
			}
			result = result + ((double)count/totalNum)*resultInAttribute;
		}
		return result;
	}
	public String bestAttri() {
		double min = 100;
		String choose = "";
		for(int i=0;i<this.attribute.size();i++) {
			double cal = calEntropy(this.attribute.get(i));
			if(min>cal) {
				min = cal;
				choose = this.attribute.get(i);
			}
		}
		return choose;
	}
	Map<String,LinkedList<String[]>> divideByAttribute(String attribute){
		LinkedList<String> attriValue = this.attributeValue.get(attribute);//获得当前属性下的属性值
		String[] content = this.dataList.get(0);//从本类的数据中拿出第0个,为了判断当前的attribute在哪一列
		int col=0;
		for(int i=1;i<content.length-1;i++) {
			if(attriValue.contains(content[i])) {
				col = i;//找到当前attribute所在的列
				break;
			}
		}
		Map<String,LinkedList<String[]>> result = new HashMap<String,LinkedList<String[]>>();//结果集
		//下面开始按attribute的值对dataList分类
		for(int i=0;i<attriValue.size();i++) {//遍历
			LinkedList<String[]> resultValue = new LinkedList<String[]>();//当前attribute[i]的值
			for(int j=0;j<this.dataList.size();j++) {
				String[] temp = this.dataList.get(j);
				if(temp[col].equals(attriValue.get(i)))resultValue.add(temp);
			}
			if(resultValue.size()!=0)result.put(attriValue.get(i), resultValue);
		}
		return result;
	}
	
}

 

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值