package DecisionTree;
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;
/**
 * Test driver for the ID3 decision tree: reads a whitespace-separated data file, builds the tree
 * with {@link DecisionTree} and prints it to standard output.
 *
 * <p>File layout as read by the methods below: the first line holds the candidate attribute
 * names; every following non-blank line holds one training tuple (DecisionTree.classOfDatas
 * treats the last token of each tuple as its class label).
 *
 * @author Pjq
 * @qq 378290226
 * @mail 378290226@qq.com
 * @date 2012.04.17
 */
public class TestDecisionTree {
    /** Data file read by the no-argument readers (resolved against the working directory). */
    static final String DEFAULT_DATA_FILE = "数据挖掘数据--玩或学习.txt";

    // NOTE(review): redFileRecord and length are not used anywhere in this class; they are kept
    // only because they are package-visible and other code might still reference them.
    // Record array intended to hold rows read from the file (redFileRecord[0][] = attributes).
    String redFileRecord[][] = new String[100][];
    int length = 0; // number of records (unused)
    // NOTE(review): no longer assigned; the readers now use a local stream that they close.
    FileInputStream file1;

    /**
     * Reads the candidate attribute names from the first line of {@link #DEFAULT_DATA_FILE}.
     *
     * @return the candidate attributes; empty if the file is missing, empty or unreadable
     * @throws IOException declared for compatibility; read errors are caught and printed
     */
    public ArrayList<String> readCandAttr() throws IOException {
        return readCandAttr(DEFAULT_DATA_FILE);
    }

    /**
     * Reads the candidate attribute names from the first line of the given file.
     *
     * @param path path of the data file
     * @return the whitespace-separated tokens of the first line; empty if the file is missing,
     *     empty or unreadable
     * @throws IOException declared for compatibility; read errors are caught and printed
     */
    public ArrayList<String> readCandAttr(String path) throws IOException {
        ArrayList<String> candAttr = new ArrayList<String>();
        BufferedReader reader = null;
        try {
            // NOTE(review): decodes with the platform default charset, as before; the file
            // name suggests Chinese content, so confirm the file's encoding matches.
            reader = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
            String header = reader.readLine();
            if (header != null) {
                candAttr.addAll(tokenize(header));
            }
        } catch (IOException e) { // also covers FileNotFoundException
            e.printStackTrace();
        } finally {
            closeReader(reader);
        }
        return candAttr;
    }

    /**
     * Reads the training tuples (every line after the header) from {@link #DEFAULT_DATA_FILE}.
     *
     * @return the training tuples; empty if the file is missing or unreadable
     * @throws IOException declared for compatibility; read errors are caught and printed
     */
    public ArrayList<ArrayList<String>> readData() throws IOException {
        return readData(DEFAULT_DATA_FILE);
    }

    /**
     * Reads the training tuples from the given file, skipping the header line and blank lines.
     *
     * @param path path of the data file
     * @return one list of whitespace-separated tokens per non-blank data line
     * @throws IOException declared for compatibility; read errors are caught and printed
     */
    public ArrayList<ArrayList<String>> readData(String path) throws IOException {
        ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
            reader.readLine(); // skip the header row holding the attribute names
            String line;
            while ((line = reader.readLine()) != null) {
                ArrayList<String> tuple = tokenize(line);
                // A blank line (e.g. a trailing newline) would otherwise become the tuple [""]
                // and break column indexing during tree construction.
                if (!tuple.isEmpty()) {
                    datas.add(tuple);
                }
            }
        } catch (IOException e) { // also covers FileNotFoundException
            e.printStackTrace();
        } finally {
            closeReader(reader);
        }
        return datas;
    }

    /** Splits a line on runs of whitespace; returns an empty list for a blank line. */
    private static ArrayList<String> tokenize(String line) {
        ArrayList<String> tokens = new ArrayList<String>();
        String trimmed = line.trim();
        if (trimmed.length() == 0) {
            return tokens;
        }
        for (String token : trimmed.split("\\s+")) {
            tokens.add(token);
        }
        return tokens;
    }

    /** Closes the reader (and the stream beneath it) if it was opened, reporting close errors. */
    private static void closeReader(BufferedReader reader) {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Recursively prints the tree: each node's name, then one indented "rule--> " line per
     * branch followed by that branch's subtree.
     *
     * @param root  the node to print
     * @param level depth of {@code root}; controls the indentation of its branches
     */
    public void printTree(TreeNode root, int level) {
        System.out.println(root.getName());
        ArrayList<String> rules = root.getRule();
        ArrayList<TreeNode> children = root.getChild();
        for (int i = 0; i < rules.size(); i++) {
            for (int j = 0; j <= level; j++) {
                System.out.print(" ");
            }
            System.out.print(rules.get(i) + "--> ");
            printTree(children.get(i), level + 1);
        }
    }

    /**
     * Program entry point: reads the data file, builds the tree and prints it.
     *
     * @param args optional; {@code args[0]} overrides the data file path
     */
    public static void main(String[] args) {
        String path = args.length > 0 ? args[0] : DEFAULT_DATA_FILE;
        TestDecisionTree tdt = new TestDecisionTree();
        ArrayList<String> candAttr; // candidate attributes
        ArrayList<ArrayList<String>> datas; // training tuples
        try {
            candAttr = tdt.readCandAttr(path);
            datas = tdt.readData(path);
        } catch (IOException e) {
            e.printStackTrace();
            return;
        }
        if (candAttr.isEmpty() || datas.isEmpty()) {
            System.err.println("No training data could be read from " + path);
            return;
        }
        DecisionTree tree = new DecisionTree();
        TreeNode root = tree.buildTree(datas, candAttr);
        tdt.printTree(root, 0);
    }
}
package DecisionTree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
/**
 * Builds an ID3-style decision tree from in-memory training tuples.
 *
 * <p>Each tuple is a list of attribute values whose last element is the class label
 * (see {@link #classOfDatas}).
 *
 * @author Pjq
 * @qq 378290226
 * @mail 378290226@qq.com
 * @date 2012.04.17
 */
public class DecisionTree {
    // Split-attribute selection mode: 1 = information gain, 2 = gain ratio (2 is not implemented).
    // NOTE(review): buildTree never reads this field; it always uses information gain via Gain.
    private Integer attrSelMode;

    /** Creates a builder using information gain (mode 1). */
    public DecisionTree() {
        this.attrSelMode = Integer.valueOf(1);
    }

    /** Creates a builder with the given selection mode (1 = information gain). */
    public DecisionTree(int attrSelMode) {
        this.attrSelMode = Integer.valueOf(attrSelMode);
    }

    /** Sets the selection mode (1 = information gain; 2 = gain ratio, not implemented). */
    public void setAttrSelMode(Integer attrSelMode) {
        this.attrSelMode = attrSelMode;
    }

    /**
     * Counts how many tuples belong to each class.
     *
     * @param datas the tuples; the last element of each tuple is its class label
     * @return map from class label to number of tuples with that label
     */
    public Map<String, Integer> classOfDatas(ArrayList<ArrayList<String>> datas) {
        Map<String, Integer> classes = new HashMap<String, Integer>();
        for (ArrayList<String> tuple : datas) {
            String label = tuple.get(tuple.size() - 1);
            Integer count = classes.get(label);
            classes.put(label, count == null ? 1 : count + 1);
        }
        return classes;
    }

    /**
     * Returns the majority class, i.e. the label with the largest count.
     *
     * @param classes map from class label to count
     * @return the label with the highest count ("" if the map is empty); ties resolve to the
     *     first such label in the map's iteration order
     */
    public String maxClass(Map<String, Integer> classes) {
        String maxC = "";
        int max = -1;
        for (Map.Entry<String, Integer> entry : classes.entrySet()) {
            if (entry.getValue() > max) {
                max = entry.getValue();
                maxC = entry.getKey();
            }
        }
        return maxC;
    }

    /**
     * Builds a decision tree recursively.
     *
     * <p>Neither {@code datas} nor {@code attrList} is modified: every child receives fresh
     * copies of its tuples (with the split column removed) and of the remaining attributes.
     *
     * @param datas    training tuples; the last column is the class label
     * @param attrList candidate attribute names, aligned with the tuple columns
     * @return the root of the (sub)tree. Internal nodes are named after the split attribute.
     *     Leaves are named after a class label. An empty {@code datas} yields an unnamed leaf.
     */
    public TreeNode buildTree(ArrayList<ArrayList<String>> datas,
            ArrayList<String> attrList) {
        TreeNode node = new TreeNode();
        node.setDatas(datas);
        node.setCandAttr(attrList);
        Map<String, Integer> classes = classOfDatas(datas);
        if (classes.isEmpty()) {
            return node; // no tuples at all: nothing to label
        }
        String majority = maxClass(classes);
        if (classes.size() == 1 || attrList.isEmpty()) {
            // Pure node, or no attributes left to split on: label with the (majority) class.
            node.setName(majority);
            return node;
        }
        Gain gain = new Gain(datas, attrList);
        double styWhoEx = gain.getStylebookWholeExpection(classes, datas.size()); // entropy of datas
        int bestAttrIndex = gain.bestGainAttrIndex(styWhoEx);
        int classIndex = datas.get(0).size() - 1;
        if (bestAttrIndex < 0 || bestAttrIndex >= attrList.size()
                || bestAttrIndex >= classIndex) {
            // No attribute has positive gain, or the only candidate is the class column itself:
            // splitting further cannot help, so stop with the majority class.
            node.setName(majority);
            return node;
        }
        ArrayList<String> rules = gain.getValues(datas, bestAttrIndex); // values of the split attribute
        node.setRule(rules);
        node.setName(attrList.get(bestAttrIndex));

        // The split column is removed from every child tuple, so the attribute name must be
        // removed too, always, to keep names and columns aligned. A copy keeps siblings and the
        // caller unaffected.
        ArrayList<String> childAttrs = new ArrayList<String>(attrList);
        childAttrs.remove(bestAttrIndex);

        for (String rule : rules) {
            ArrayList<ArrayList<String>> di = new ArrayList<ArrayList<String>>();
            for (ArrayList<String> tuple : gain.datasOfValue(bestAttrIndex, rule)) {
                // Copy before removing the column: the originals are still being partitioned
                // for the remaining rules and must keep their column layout.
                ArrayList<String> reduced = new ArrayList<String>(tuple);
                reduced.remove(bestAttrIndex);
                di.add(reduced);
            }
            if (di.isEmpty()) {
                // Defensive: rules come from values present in datas, so this should not occur.
                TreeNode leafNode = new TreeNode();
                leafNode.setName(majority);
                leafNode.setDatas(di);
                leafNode.setCandAttr(childAttrs);
                node.getChild().add(leafNode);
            } else {
                node.getChild().add(buildTree(di, childAttrs));
            }
        }
        return node;
    }
}
package DecisionTree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.math.BigDecimal;
import static java.lang.Math.*;
/**
 * Selects the best split attribute for ID3 by information gain.
 *
 * <p>Works on a fixed set of training tuples whose last column is the class label, with any
 * number of distinct class labels.
 *
 * @author Pjq
 * @qq 378290226
 * @mail 378290226@qq.com
 * @date 2012.04.17
 */
public class Gain {
    // Gains at or below this are treated as zero, so rounding noise cannot make an
    // uninformative attribute look useful.
    private static final double MIN_GAIN = 1e-12;

    private ArrayList<ArrayList<String>> D = null; // training tuples (last column = class label)
    private ArrayList<String> attrList = null; // candidate attribute names

    /**
     * @param datas    training tuples; the last element of each tuple is its class label
     * @param attrList candidate attribute names, aligned with the tuple columns
     */
    public Gain(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList) {
        this.D = datas;
        this.attrList = attrList;
    }

    /**
     * Returns the distinct values of one column, in first-seen order (values are assumed to be
     * nominal/categorical).
     *
     * @param datas     the tuples to scan
     * @param attrIndex index of the attribute column
     * @return distinct values of that column
     */
    public ArrayList<String> getValues(ArrayList<ArrayList<String>> datas,
            int attrIndex) {
        ArrayList<String> values = new ArrayList<String>();
        for (ArrayList<String> tuple : datas) {
            String value = tuple.get(attrIndex);
            if (!values.contains(value)) {
                values.add(value);
            }
        }
        return values;
    }

    /**
     * Counts the occurrences of each value in one column.
     *
     * @param datas     the tuples to scan
     * @param attrIndex index of the attribute column
     * @return map from column value to number of tuples having it
     */
    public Map<String, Integer> valueCounts(ArrayList<ArrayList<String>> datas,
            int attrIndex) {
        Map<String, Integer> valueCount = new HashMap<String, Integer>();
        for (ArrayList<String> tuple : datas) {
            String value = tuple.get(attrIndex);
            Integer count = valueCount.get(value);
            valueCount.put(value, count == null ? 1 : count + 1);
        }
        return valueCount;
    }

    /**
     * Returns the training tuples whose given column equals {@code value}. The returned list
     * holds the same tuple objects as the training set, not copies.
     *
     * @param attrIndex index of the attribute column
     * @param value     the value to select
     * @return matching tuples, in training-set order
     */
    public ArrayList<ArrayList<String>> datasOfValue(int attrIndex, String value) {
        ArrayList<ArrayList<String>> Di = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> tuple : D) {
            if (tuple.get(attrIndex).equals(value)) {
                Di.add(tuple);
            }
        }
        return Di;
    }

    /**
     * Expected information (weighted class entropy) still needed to classify the training
     * tuples after partitioning them on the given attribute: sum over values v of
     * |D_v| / |D| * Entropy(D_v).
     *
     * @param attrIndex index of the attribute column
     * @return the expected information in bits; 0 for an empty training set
     */
    public double infoAttr(int attrIndex) {
        double info = 0.0;
        if (D.isEmpty()) {
            return info;
        }
        for (String value : getValues(D, attrIndex)) {
            ArrayList<ArrayList<String>> dv = datasOfValue(attrIndex, value);
            double weight = (double) dv.size() / (double) D.size();
            // Entropy over all class labels present in the partition. This replaces a version
            // hard-wired to the labels "Play" and "Study", which ignored every other label.
            info += weight * getStylebookWholeExpection(classCounts(dv), dv.size());
        }
        return info;
    }

    /**
     * Returns the index of the attribute with the highest information gain.
     *
     * <p>The last tuple column is the class label and is never considered, even if the header
     * listed its name among the candidates. Ties go to the lower index.
     *
     * @param styWhoEx entropy of the whole training set (see {@link #getStylebookWholeExpection})
     * @return index of the best attribute, or -1 if no attribute has positive gain
     */
    public int bestGainAttrIndex(double styWhoEx) {
        int limit = attrList.size();
        if (!D.isEmpty()) {
            limit = min(limit, D.get(0).size() - 1);
        }
        int index = -1;
        double bestGain = MIN_GAIN;
        for (int i = 0; i < limit; i++) {
            double candidateGain = styWhoEx - infoAttr(i);
            if (candidateGain > bestGain) {
                bestGain = candidateGain;
                index = i;
            }
        }
        return index;
    }

    /**
     * Entropy (in bits) of a class distribution: -sum p_c * log2(p_c).
     *
     * @param classes map from class label to count
     * @param n       total number of tuples the counts were taken from
     * @return the entropy; 0 when {@code n <= 0}
     */
    public double getStylebookWholeExpection(Map<String, Integer> classes, int n) {
        double entropy = 0.0;
        if (n <= 0) {
            return entropy;
        }
        for (Integer count : classes.values()) {
            if (count == null || count <= 0) {
                continue; // 0 * log(0) is taken as 0
            }
            double p = (double) count / (double) n;
            entropy -= p * (log(p) / log(2));
        }
        return entropy;
    }

    /** Counts tuples per class label (last column), like DecisionTree.classOfDatas. */
    private static Map<String, Integer> classCounts(ArrayList<ArrayList<String>> datas) {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (ArrayList<String> tuple : datas) {
            String label = tuple.get(tuple.size() - 1);
            Integer count = counts.get(label);
            counts.put(label, count == null ? 1 : count + 1);
        }
        return counts;
    }
}
package DecisionTree;
import java.util.ArrayList;
/**
 * A node of the decision tree.
 *
 * <p>{@code rule.get(i)} is the attribute value that leads to {@code child.get(i)}.
 *
 * @author pjq
 * @qq 378290226
 * @data 2011.03.15
 */
public class TreeNode {
    /** Node name: the split attribute for internal nodes, a class label for leaves. */
    private String name = "";
    /** The attribute values this node branches on (its split rules). */
    private ArrayList<String> rule = new ArrayList<String>();
    /** Child nodes, one per rule (package-private, as before). */
    ArrayList<TreeNode> child = new ArrayList<TreeNode>();
    /** Training tuples routed to this node; null until set. */
    private ArrayList<ArrayList<String>> datas;
    /** Candidate attributes available at this node; null until set. */
    private ArrayList<String> candAttr;

    /** Creates an unnamed node with no rules, no children, and no data or candidates set. */
    public TreeNode() {
        // Every field is initialised at its declaration.
    }

    public String getName() {
        return this.name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public ArrayList<String> getRule() {
        return this.rule;
    }

    public void setRule(ArrayList<String> rule) {
        this.rule = rule;
    }

    public ArrayList<TreeNode> getChild() {
        return this.child;
    }

    public void setChild(ArrayList<TreeNode> child) {
        this.child = child;
    }

    public ArrayList<ArrayList<String>> getDatas() {
        return this.datas;
    }

    public void setDatas(ArrayList<ArrayList<String>> datas) {
        this.datas = datas;
    }

    public ArrayList<String> getCandAttr() {
        return this.candAttr;
    }

    public void setCandAttr(ArrayList<String> candAttr) {
        this.candAttr = candAttr;
    }
}
转载:
java实现决策树ID3算法(文件读取)