java实现决策树ID3算法(文件读取)

package DecisionTree;

import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

/**
 * Test driver for the ID3 decision-tree implementation: reads candidate
 * attributes and training tuples from a whitespace-separated text file,
 * builds the tree and prints it.
 * <p>
 * File layout expected by the readers below: the first non-blank line holds
 * the candidate attribute names, every following non-blank line holds one
 * training tuple.
 * 
 * @author Pjq
 * @qq 378290226
 * @mail 378290226@qq.com
 * @date 2012.04.17
 */
public class TestDecisionTree {

	/** Data file used when no file name is supplied. */
	private static final String DATA_FILE = "数据挖掘数据--玩或学习.txt";

	// Record array; redFileRecord[0][] was intended to hold the candidate
	// attributes. No method in this class uses it or the two fields below;
	// they are kept so other code in the package that references them still
	// compiles.
	String redFileRecord[][] = new String[100][];
	int length = 0; // record count (unused here)
	FileInputStream file1; // no longer assigned: each read opens and closes its own stream

	/**
	 * Reads the candidate attributes from the default data file.
	 * 
	 * @return candidate attribute names in file order (empty if the file has
	 *         no non-blank line)
	 * @throws IOException
	 *             if the file cannot be opened or read
	 */
	public ArrayList<String> readCandAttr() throws IOException {
		return readCandAttr(DATA_FILE);
	}

	/**
	 * Reads the candidate attributes, i.e. the tokens of the first non-blank
	 * line of {@code fileName}.
	 * 
	 * @param fileName
	 *            path of the data file
	 * @return candidate attribute names in file order (empty if the file has
	 *         no non-blank line)
	 * @throws IOException
	 *             if the file cannot be opened or read
	 */
	public ArrayList<String> readCandAttr(String fileName) throws IOException {
		// Only the header is needed, so stop after the first non-blank line.
		ArrayList<ArrayList<String>> header = readTokenizedLines(fileName, 1);
		if (header.isEmpty()) {
			return new ArrayList<String>();
		}
		return header.get(0);
	}

	/**
	 * Reads the training tuples from the default data file.
	 * 
	 * @return training tuples, one list of tokens per data line
	 * @throws IOException
	 *             if the file cannot be opened or read
	 */
	public ArrayList<ArrayList<String>> readData() throws IOException {
		return readData(DATA_FILE);
	}

	/**
	 * Reads the training tuples: every non-blank line of {@code fileName}
	 * after the attribute-name header, split on whitespace.
	 * 
	 * @param fileName
	 *            path of the data file
	 * @return training tuples, one list of tokens per data line
	 * @throws IOException
	 *             if the file cannot be opened or read
	 */
	public ArrayList<ArrayList<String>> readData(String fileName)
			throws IOException {
		ArrayList<ArrayList<String>> lines = readTokenizedLines(fileName, -1);
		if (!lines.isEmpty()) {
			lines.remove(0); // drop the attribute-name header line
		}
		return lines;
	}

	/**
	 * Reads up to {@code maxLines} non-blank lines of a file and splits each
	 * on runs of whitespace. The file is always closed before returning.
	 * 
	 * @param fileName
	 *            path of the file to read
	 * @param maxLines
	 *            maximum number of non-blank lines to return; negative means
	 *            no limit
	 * @return one token list per non-blank line, in file order
	 * @throws IOException
	 *             if the file cannot be opened or read
	 */
	private static ArrayList<ArrayList<String>> readTokenizedLines(
			String fileName, int maxLines) throws IOException {
		ArrayList<ArrayList<String>> lines = new ArrayList<ArrayList<String>>();
		// NOTE(review): uses the platform default charset, as the original
		// code did; the data file's encoding is not visible here.
		BufferedReader reader = new BufferedReader(new InputStreamReader(
				new FileInputStream(fileName)));
		try {
			String line;
			while ((maxLines < 0 || lines.size() < maxLines)
					&& (line = reader.readLine()) != null) {
				String trimmed = line.trim();
				if (trimmed.length() == 0) {
					continue; // blank line: neither a header nor a tuple
				}
				// "\\s+" tolerates repeated spaces and tabs between fields.
				String[] tokens = trimmed.split("\\s+");
				ArrayList<String> row = new ArrayList<String>(tokens.length);
				for (String token : tokens) {
					row.add(token);
				}
				lines.add(row);
			}
		} finally {
			reader.close();
		}
		return lines;
	}

	/**
	 * Recursively prints the tree: the node name, then for each rule an
	 * indented "rule--> " prefix followed by the corresponding child subtree.
	 * 
	 * @param root
	 *            node whose subtree is printed
	 * @param level
	 *            depth of {@code root}; controls indentation of its rules
	 */
	public void printTree(TreeNode root, int level) {
		System.out.println(root.getName());
		ArrayList<String> rules = root.getRule();
		ArrayList<TreeNode> children = root.getChild();
		for (int i = 0; i < rules.size(); i++) {
			for (int j = 0; j <= level; j++) {
				System.out.print("     ");
			}
			System.out.print(rules.get(i) + "--> ");
			printTree(children.get(i), level + 1);
		}
	}

	/**
	 * Program entry point: builds and prints the decision tree.
	 * 
	 * @param args
	 *            optional; {@code args[0]} overrides the data file path
	 */
	public static void main(String[] args) {
		String fileName = (args.length > 0) ? args[0] : DATA_FILE;
		TestDecisionTree tdt = new TestDecisionTree();
		ArrayList<String> candAttr; // candidate attributes
		ArrayList<ArrayList<String>> datas; // training tuples

		try {
			candAttr = tdt.readCandAttr(fileName);
			datas = tdt.readData(fileName);
		} catch (IOException e) {
			e.printStackTrace();
			return; // nothing to build without data
		}
		if (candAttr.isEmpty() || datas.isEmpty()) {
			System.err.println("No candidate attributes or training tuples read from "
					+ fileName);
			return;
		}
		DecisionTree tree = new DecisionTree();
		TreeNode root = tree.buildTree(datas, candAttr);
		tdt.printTree(root, 0);
	}

}
package DecisionTree;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/**
 * Builds an ID3 decision tree from training tuples.
 * <p>
 * A training tuple is an {@code ArrayList<String>}. Its last element is read
 * as the class label (see {@link #classOfDatas}). Attribute {@code i} of the
 * attribute list is read from position {@code i} of the tuple.
 * NOTE(review): this assumes the attribute list does not name the class
 * column. Confirm this against the data file header.
 * 
 * @author Pjq
 * @qq 378290226
 * @mail 378290226@qq.com
 * @date 2012.04.17
 */
public class DecisionTree {
	// Split-attribute selection mode: 1 = information gain, 2 = gain ratio.
	// The value is only stored; buildTree always uses information gain
	// (mode 2 was never implemented).
	private Integer attrSelMode;

	/** Creates a builder in mode 1 (information gain). */
	public DecisionTree() {
		this.attrSelMode = Integer.valueOf(1);
	}

	/**
	 * @param attrSelMode
	 *            selection mode to store (see the field comment)
	 */
	public DecisionTree(int attrSelMode) {
		this.attrSelMode = Integer.valueOf(attrSelMode);
	}

	public void setAttrSelMode(Integer attrSelMode) {
		this.attrSelMode = attrSelMode;
	}

	/**
	 * Counts the tuples of each class in a data set.
	 * 
	 * @param datas
	 *            tuples whose last element is the class label
	 * @return map from class label to number of tuples with that label
	 */
	public Map<String, Integer> classOfDatas(ArrayList<ArrayList<String>> datas) {
		Map<String, Integer> classes = new HashMap<String, Integer>();
		for (ArrayList<String> tuple : datas) {
			String c = tuple.get(tuple.size() - 1);
			Integer count = classes.get(c);
			classes.put(c, count == null ? 1 : count + 1);
		}
		return classes;
	}

	/**
	 * Returns the class with the largest count (the majority class).
	 * 
	 * @param classes
	 *            map from class label to count
	 * @return the label with the highest count; {@code ""} if the map is empty
	 */
	public String maxClass(Map<String, Integer> classes) {
		String maxC = "";
		int max = -1;
		for (Map.Entry<String, Integer> entry : classes.entrySet()) {
			int val = entry.getValue();
			if (val > max) {
				max = val;
				maxC = entry.getKey();
			}
		}
		return maxC;
	}

	/**
	 * Recursively builds the decision tree.
	 * <p>
	 * Neither {@code datas} nor {@code attrList} is modified. Each child is
	 * built from copies with the chosen attribute's column removed, so
	 * attribute names and tuple columns stay aligned in every subtree.
	 * 
	 * @param datas
	 *            training tuples routed to this node
	 * @param attrList
	 *            candidate attribute names; index i refers to tuple column i
	 * @return the root of the (sub)tree. Its name is either the class label
	 *         (leaf) or the splitting attribute (internal node with one rule
	 *         and one child per observed value). With no tuples, the node is
	 *         an unnamed leaf.
	 */
	public TreeNode buildTree(ArrayList<ArrayList<String>> datas,
			ArrayList<String> attrList) {
		TreeNode node = new TreeNode();
		node.setDatas(datas);
		node.setCandAttr(attrList);
		Map<String, Integer> classes = classOfDatas(datas); // class -> count
		if (classes.isEmpty()) {
			// No tuples at all: nothing to classify.
			return node;
		}
		if (classes.size() == 1) {
			// Pure node: every tuple has the same class.
			node.setName(classes.keySet().iterator().next());
			return node;
		}
		String majority = maxClass(classes);

		Gain gain = new Gain(datas, attrList);
		double styWhoEx = gain
				.getStylebookWholeExpection(classes, datas.size()); // entropy of this node
		int bestAttrIndex = gain.bestGainAttrIndex(styWhoEx);
		if (bestAttrIndex < 0) {
			// No attribute left, or none gives a positive gain: a further
			// split cannot help, so label the node with the majority class.
			node.setName(majority);
			return node;
		}
		ArrayList<String> rules = gain.getValues(datas, bestAttrIndex); // observed values
		node.setRule(rules);
		node.setName(attrList.get(bestAttrIndex));

		// The chosen attribute is consumed for the whole subtree. Every child
		// shares this copy, which is never modified afterwards.
		ArrayList<String> childAttrList = new ArrayList<String>(attrList);
		childAttrList.remove(bestAttrIndex);

		for (String rule : rules) {
			ArrayList<ArrayList<String>> childDatas = withoutColumn(
					gain.datasOfValue(bestAttrIndex, rule), bestAttrIndex);
			if (childDatas.isEmpty()) {
				// Defensive: an empty branch is labelled with the parent's
				// majority class.
				TreeNode leafNode = new TreeNode();
				leafNode.setName(majority);
				leafNode.setDatas(childDatas);
				leafNode.setCandAttr(childAttrList);
				node.getChild().add(leafNode);
			} else {
				node.getChild().add(buildTree(childDatas, childAttrList));
			}
		}
		return node;
	}

	/**
	 * Returns copies of {@code tuples} with the element at {@code column}
	 * removed; the input tuples are left untouched.
	 */
	private static ArrayList<ArrayList<String>> withoutColumn(
			ArrayList<ArrayList<String>> tuples, int column) {
		ArrayList<ArrayList<String>> result = new ArrayList<ArrayList<String>>(
				tuples.size());
		for (ArrayList<String> tuple : tuples) {
			ArrayList<String> copy = new ArrayList<String>(tuple);
			copy.remove(column); // remove(int): by index
			result.add(copy);
		}
		return result;
	}
}
package DecisionTree;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.math.BigDecimal;
import static java.lang.Math.*;

/**
 * Selects the best splitting attribute of a set of training tuples by
 * information gain (ID3).
 * <p>
 * The class label of a tuple is its last element, as in
 * {@code DecisionTree.classOfDatas}. Any number of distinct class labels is
 * supported.
 * 
 * @author Pjq
 * @qq 378290226
 * @mail 378290226@qq.com
 * @date 2012.04.17
 */
public class Gain {
	private ArrayList<ArrayList<String>> D = null; // training tuples
	private ArrayList<String> attrList = null; // candidate attributes

	/**
	 * @param datas
	 *            training tuples; the last element of each is its class label
	 * @param attrList
	 *            candidate attribute names; index i refers to tuple column i
	 */
	public Gain(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList) {
		this.D = datas;
		this.attrList = attrList;
	}

	/**
	 * Returns the distinct values of one attribute column, in first-seen
	 * order. Values are treated as nominal/categorical.
	 * 
	 * @param datas
	 *            tuples to scan
	 * @param attrIndex
	 *            index of the attribute column
	 * @return the distinct values of that column
	 */
	public ArrayList<String> getValues(ArrayList<ArrayList<String>> datas,
			int attrIndex) {
		ArrayList<String> values = new ArrayList<String>();
		for (int i = 0; i < datas.size(); i++) {
			String r = datas.get(i).get(attrIndex);
			if (!values.contains(r)) {
				values.add(r);
			}
		}
		return values;
	}

	/**
	 * Counts how often each value occurs in one column of a data set.
	 * 
	 * @param datas
	 *            tuples to scan
	 * @param attrIndex
	 *            index of the column
	 * @return map from column value to occurrence count
	 */
	public Map<String, Integer> valueCounts(ArrayList<ArrayList<String>> datas,
			int attrIndex) {
		Map<String, Integer> valueCount = new HashMap<String, Integer>();
		for (int i = 0; i < datas.size(); i++) {
			String c = datas.get(i).get(attrIndex);
			if (valueCount.containsKey(c)) {
				valueCount.put(c, valueCount.get(c) + 1);
			} else {
				valueCount.put(c, 1);
			}
		}
		return valueCount;
	}

	/**
	 * Returns the training tuples whose given column equals {@code value}.
	 * The returned list holds the same tuple objects as the training set,
	 * not copies.
	 * 
	 * @param attrIndex
	 *            index of the attribute column
	 * @param value
	 *            attribute value to select
	 * @return matching tuples, in training-set order
	 */
	public ArrayList<ArrayList<String>> datasOfValue(int attrIndex, String value) {
		ArrayList<ArrayList<String>> Di = new ArrayList<ArrayList<String>>();
		for (int i = 0; i < D.size(); i++) {
			ArrayList<String> t = D.get(i);
			if (t.get(attrIndex).equals(value)) {
				Di.add(t);
			}
		}
		return Di;
	}

	/**
	 * Returns the expected information (in bits) still needed to classify the
	 * training tuples after partitioning them on one attribute.
	 * <p>
	 * The result is sum over values v of |D_v|/|D| * H(D_v), where H is the
	 * class entropy over all class labels present in D_v.
	 * 
	 * @param attrIndex
	 *            index of the attribute to partition on
	 * @return expected information after the split
	 */
	public double infoAttr(int attrIndex) {
		double info = 0.0;
		ArrayList<String> values = getValues(D, attrIndex);
		for (int i = 0; i < values.size(); i++) {
			ArrayList<ArrayList<String>> dv = datasOfValue(attrIndex,
					values.get(i));
			double weight = ((double) dv.size()) / ((double) D.size());
			info += weight * entropy(classCounts(dv), dv.size());
		}
		return info;
	}

	/**
	 * Returns the index of the attribute with the largest information gain.
	 * 
	 * @param styWhoEx
	 *            entropy of the whole training set
	 * @return index into the attribute list, or -1 if no attribute has a
	 *         strictly positive gain (including when the list is empty)
	 */
	public int bestGainAttrIndex(double styWhoEx) {
		int index = -1;
		double gain = 0.000;
		for (int i = 0; i < attrList.size(); i++) {
			double tempGain = styWhoEx - infoAttr(i);
			if (tempGain > gain) {
				gain = tempGain;
				index = i;
			}
		}
		return index;
	}

	/**
	 * Returns the class entropy (in bits) of a sample.
	 * 
	 * @param classes
	 *            map from class label to count
	 * @param n
	 *            total number of tuples
	 * @return -sum p*log2(p) over the classes; 0 for an empty map
	 */
	public double getStylebookWholeExpection(Map<String, Integer> classes, int n) {
		return entropy(classes, n);
	}

	/** Counts tuples per class label (the last element of each tuple). */
	private static Map<String, Integer> classCounts(
			ArrayList<ArrayList<String>> datas) {
		Map<String, Integer> counts = new HashMap<String, Integer>();
		for (ArrayList<String> tuple : datas) {
			String c = tuple.get(tuple.size() - 1);
			Integer count = counts.get(c);
			counts.put(c, count == null ? 1 : count + 1);
		}
		return counts;
	}

	/** Entropy in bits of the distribution given by {@code counts} over {@code n} items. */
	private static double entropy(Map<String, Integer> counts, int n) {
		double h = 0.0;
		for (Integer count : counts.values()) {
			if (count == null || count.intValue() <= 0) {
				continue; // 0 * log(0) is taken as 0 (avoids NaN)
			}
			double p = (double) count.intValue() / (double) n;
			h -= p * (log(p) / log(2));
		}
		return h;
	}

}

package DecisionTree;  
import java.util.ArrayList;  
/**
 * Decision-tree node.
 * <p>
 * As used by DecisionTree.buildTree and TestDecisionTree.printTree, an
 * internal node is named after its splitting attribute and holds one rule
 * (attribute value) per child, in the same order as the child list.
 *
 * @author pjq
 * @qq 378290226
 * @date 2011.03.15
 */
public class TreeNode {
    private String name = "";                                  // node name (splitting attribute, or class label at a leaf)
    private ArrayList<String> rule = new ArrayList<String>();  // split rules: one attribute value per child
    ArrayList<TreeNode> child = new ArrayList<TreeNode>();     // child nodes (package-private, as before)
    private ArrayList<ArrayList<String>> datas = null;         // training tuples routed to this node
    private ArrayList<String> candAttr = null;                 // candidate attributes at this node

    /** Creates a node with an empty name, no rules, no children and no data. */
    public TreeNode() {
        // every field is initialised by its declaration above
    }

    public String getName() {
        return this.name;
    }

    public void setName(String value) {
        this.name = value;
    }

    public ArrayList<String> getRule() {
        return this.rule;
    }

    public void setRule(ArrayList<String> value) {
        this.rule = value;
    }

    public ArrayList<TreeNode> getChild() {
        return this.child;
    }

    public void setChild(ArrayList<TreeNode> value) {
        this.child = value;
    }

    public ArrayList<ArrayList<String>> getDatas() {
        return this.datas;
    }

    public void setDatas(ArrayList<ArrayList<String>> value) {
        this.datas = value;
    }

    public ArrayList<String> getCandAttr() {
        return this.candAttr;
    }

    public void setCandAttr(ArrayList<String> value) {
        this.candAttr = value;
    }
}
转载: java实现决策树ID3算法(文件读取)


  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值