决策树之 ID3 算法的 Java 实现

package com.decisiontree;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;

public class ID3 {

	/** Column of every sample row that holds the class label ("yes" / "no"). */
	private static final int CLASS_INDEX = 4;

	/**
	 * Builds an ID3 decision tree for the 14-row weather ("play") sample set and prints it.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		// Every value each attribute can take.
		ArrayList<String> attrOutlook = new ArrayList<String>();
		attrOutlook.add("sunny");
		attrOutlook.add("overcast");
		attrOutlook.add("rainy");
		ArrayList<String> attrTemperature = new ArrayList<String>();
		attrTemperature.add("hot");
		attrTemperature.add("mild");
		attrTemperature.add("cool");
		ArrayList<String> attrHumidity = new ArrayList<String>();
		attrHumidity.add("high");
		attrHumidity.add("normal");
		ArrayList<String> attrWindy = new ArrayList<String>();
		attrWindy.add("true");
		attrWindy.add("false");

		// Candidate attribute name -> its possible values (the class column is not a candidate).
		HashMap<String, ArrayList<String>> attr = new HashMap<String, ArrayList<String>>();
		attr.put("outlook", attrOutlook);
		attr.put("temperature", attrTemperature);
		attr.put("humidity", attrHumidity);
		attr.put("windy", attrWindy);

		// Attribute name -> column index inside each sample row.
		HashMap<String, Integer> attrIndex = new HashMap<String, Integer>();
		attrIndex.put("outlook", 0);
		attrIndex.put("temperature", 1);
		attrIndex.put("humidity", 2);
		attrIndex.put("windy", 3);

		// Samples: outlook, temperature, humidity, windy, play (class label at CLASS_INDEX).
		String[][] data = {{"sunny","hot","high","false","no"},
				{"sunny","hot","high","true","no"},
				{"overcast","hot","high","false","yes"},
				{"rainy","mild","high","false","yes"},
				{"rainy","cool","normal","false","yes"},
				{"rainy","cool","normal","true","no"},
				{"overcast","cool","normal","true","yes"},
				{"sunny","mild","high","false","no"},
				{"sunny","cool","normal","false","yes"},
				{"rainy","mild","normal","false","yes"},
				{"sunny","mild","normal","true","yes"},
				{"overcast","mild","high","true","yes"},
				{"overcast","hot","normal","false","yes"},
				{"rainy","mild","high","true","no"}};
		ID3Tree root = new ID3Tree();
		buildID3Tree(root, data, attr, attrIndex);
		outputID3Tree(root);
	}

	/**
	 * Recursively grows an ID3 decision tree below {@code root}.
	 *
	 * <p>At each node the attribute whose split gives the lowest weighted class entropy (i.e. the
	 * highest information gain) is chosen; ties go to the first attribute in {@code attr}'s
	 * iteration order. One child is created per value of that attribute, in the order of its value
	 * list. A node becomes a leaf (predicting the majority class, ties -> "yes") when its rows are
	 * pure, when it has no rows, or when no candidate attributes remain. A branch that no training
	 * row reaches becomes a leaf predicting the parent's majority class.
	 *
	 * <p>{@code attr} is never modified: each level passes its children a copy with the chosen
	 * attribute removed, so sibling subtrees all see the same candidate attributes.
	 *
	 * @param root      node to fill in; its {@code attrValue} is left as set by the caller
	 * @param data      sample rows; column {@code attrIndex.get(name)} holds attribute {@code name}
	 *                  and column 4 holds the class label "yes"/"no". Reading stops at the first
	 *                  null row or row without a label, so null-padded arrays are accepted.
	 * @param attr      candidate attribute name -> list of its possible values
	 * @param attrIndex attribute name -> column index in {@code data}
	 * @return {@code root}
	 */
	public static ID3Tree buildID3Tree(ID3Tree root, String[][] data,
			HashMap<String, ArrayList<String>> attr, HashMap<String, Integer> attrIndex) {
		String[][] rows = usableRows(data);
		int yes = countLabel(rows, "yes");
		int no = countLabel(rows, "no");

		// Stop when the node is already pure or there is nothing left to split on.
		if (attr == null || attr.isEmpty() || rows.length == 0 || yes == 0 || no == 0) {
			makeLeaf(root, yes >= no);
			return root;
		}

		// Pick the attribute with the smallest weighted entropy after the split.
		String bestAttr = null;
		double bestEntropy = Double.POSITIVE_INFINITY;
		for (String name : attr.keySet()) {
			double e = splitEntropy(rows, attrIndex.get(name), attr.get(name), yes + no);
			if (e < bestEntropy) {
				bestEntropy = e;
				bestAttr = name;
			}
		}

		int column = attrIndex.get(bestAttr);
		// Children get their own copy so the caller's map and sibling subtrees are unaffected.
		HashMap<String, ArrayList<String>> remaining = new HashMap<String, ArrayList<String>>(attr);
		remaining.remove(bestAttr);

		root.attrName = bestAttr;
		root.isleaf = false;
		root.treeList = new ArrayList<ID3Tree>();
		for (String value : attr.get(bestAttr)) {
			ID3Tree child = new ID3Tree();
			child.attrValue = value;
			String[][] subset = rowsWithValue(rows, column, value);
			if (subset.length == 0) {
				// No training sample reaches this branch: predict this node's majority class.
				makeLeaf(child, yes >= no);
			} else {
				buildID3Tree(child, subset, remaining, attrIndex);
			}
			root.treeList.add(child);
		}
		return root;
	}

	/** Returns the leading rows of {@code data} that are non-null and carry a class label. */
	private static String[][] usableRows(String[][] data) {
		if (data == null) {
			return new String[0][];
		}
		int count = 0;
		while (count < data.length && data[count] != null
				&& data[count].length > CLASS_INDEX && data[count][CLASS_INDEX] != null) {
			count++;
		}
		String[][] rows = new String[count][];
		System.arraycopy(data, 0, rows, 0, count);
		return rows;
	}

	/** Counts the rows whose class label equals {@code label}. */
	private static int countLabel(String[][] rows, String label) {
		int count = 0;
		for (String[] row : rows) {
			if (label.equals(row[CLASS_INDEX])) {
				count++;
			}
		}
		return count;
	}

	/** Returns the rows whose {@code column} holds {@code value} (null-safe). */
	private static String[][] rowsWithValue(String[][] rows, int column, String value) {
		ArrayList<String[]> matches = new ArrayList<String[]>();
		for (String[] row : rows) {
			if (column < row.length && value.equals(row[column])) {
				matches.add(row);
			}
		}
		return matches.toArray(new String[0][]);
	}

	/**
	 * Weighted class entropy after splitting {@code rows} on the attribute stored in
	 * {@code column}: the sum over values v of (labelled rows with v / labelled) * H(rows with v).
	 * Lower means a more informative split.
	 */
	private static double splitEntropy(String[][] rows, int column, ArrayList<String> values,
			int labelled) {
		if (labelled == 0 || values == null) {
			return 0.0;
		}
		double weighted = 0.0;
		for (String value : values) {
			int yes = 0;
			int no = 0;
			for (String[] row : rows) {
				if (column < row.length && value.equals(row[column])) {
					if ("yes".equals(row[CLASS_INDEX])) {
						yes++;
					} else if ("no".equals(row[CLASS_INDEX])) {
						no++;
					}
				}
			}
			weighted += (double) (yes + no) / labelled * entropy(yes, no);
		}
		return weighted;
	}

	/** Entropy, in bits, of a yes/no split; 0 for an empty or pure split (0*log 0 taken as 0). */
	private static double entropy(int yes, int no) {
		int total = yes + no;
		double h = 0.0;
		for (int c : new int[] {yes, no}) {
			if (c > 0) {
				double p = (double) c / total;
				h -= p * log(p, 2.0);
			}
		}
		return h;
	}

	/** Turns {@code node} into a leaf predicting {@code play}. */
	private static void makeLeaf(ID3Tree node, boolean play) {
		node.isleaf = true;
		node.isPlay = play;
		node.attrName = null;
		node.treeList = null;
	}

	/**
	 * Prints the tree in pre-order, one node per line, as
	 * {@code attrName   attrValue  isPlay   isleaf}; {@code attrName} is null for leaves and
	 * {@code attrValue} is null for the root.
	 *
	 * @param root subtree to print
	 */
	public static void outputID3Tree(ID3Tree root) {
		System.out.println(root.attrName + "   " + root.attrValue + "  " + root.isPlay + "   " + root.isleaf);
		if (root.treeList != null) {
			for (ID3Tree child : root.treeList) {
				outputID3Tree(child);
			}
		}
	}

	/**
	 * Logarithm of {@code value} in an arbitrary base.
	 *
	 * @param value the argument of the logarithm
	 * @param base  the base of the logarithm
	 * @return log_base(value)
	 */
	static public double log(double value, double base) {
		return Math.log(value) / Math.log(base);
	}

	/** A node of the ID3 decision tree. */
	static class ID3Tree {
		// Whether this node is a leaf.
		private boolean isleaf;
		// Predicted class ("yes" = true); only meaningful on leaf nodes.
		private boolean isPlay;
		// Value of the parent's split attribute that leads to this node (null for the root).
		private String attrValue;
		// Children, one per value of attrName in that attribute's value-list order; null for leaves.
		private ArrayList<ID3Tree> treeList;
		// Attribute this node splits on; null for leaves.
		private String attrName;
		public String getAttrName() {
			return attrName;
		}
		public void setAttrName(String attrName) {
			this.attrName = attrName;
		}
		public boolean isPlay() {
			return isPlay;
		}
		public void setPlay(boolean isPlay) {
			this.isPlay = isPlay;
		}
		public boolean isIsleaf() {
			return isleaf;
		}
		public void setIsleaf(boolean isleaf) {
			this.isleaf = isleaf;
		}
		public ArrayList<ID3Tree> getTreeList() {
			return treeList;
		}
		public void setTreeList(ArrayList<ID3Tree> treeList) {
			this.treeList = treeList;
		}
		public String getAttrValue() {
			return attrValue;
		}
		public void setAttrValue(String attrValue) {
			this.attrValue = attrValue;
		}

	}
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值