1、ResultNode.java
package com.decision.my;
/**
 * One step of a root-to-leaf path through a decision tree.
 *
 * <p>As used by {@code Decision.getAllPath}, an intermediate step holds an attribute name and
 * the attribute value taken on that branch, while the final step holds only the class label
 * reached at the leaf.</p>
 */
public class ResultNode {

    /** Attribute name. */
    private String name;

    /** Attribute value. */
    private String value;

    /** Class label (field and accessor names keep the original "classfy" spelling). */
    private String classfy;

    /** @return the attribute name, or null if unset */
    public String getName() {
        return this.name;
    }

    /** @param name the attribute name */
    public void setName(String name) {
        this.name = name;
    }

    /** @return the attribute value, or null if unset */
    public String getValue() {
        return this.value;
    }

    /** @param value the attribute value */
    public void setValue(String value) {
        this.value = value;
    }

    /** @return the class label, or null if unset */
    public String getClassfy() {
        return this.classfy;
    }

    /** @param classfy the class label */
    public void setClassfy(String classfy) {
        this.classfy = classfy;
    }
}
2、TreeNode.java
package com.decision.my;
import java.util.ArrayList;
import java.util.List;
/**
 * <p>Class description: a node of the decision tree.</p>
 * <p>Internal nodes record the name of the attribute they split on; leaf nodes record a class
 * label. Every node links to its parent and children and keeps the branch value(s) on the edge
 * that leads to it from its parent.</p>
 * @author Wang Haiyang
 * @date 2015-06-23 19:22:21
 */
public class TreeNode {

    /** Name of the attribute this node splits on. */
    private String name;

    /** Split score; Decision stores the information gain of the chosen attribute here. */
    private Double value;

    /** Class label, set on leaf nodes. */
    private String classify;

    /** Parent node; null for the root. */
    private TreeNode parent;

    /** Child nodes, one per branch value. */
    private List<TreeNode> childs;

    /** Branch value(s) on the edge from the parent to this node. */
    private List<String> path;

    /** Creates a node with no name, score, class or parent, and empty child/path lists. */
    public TreeNode() {
        this.childs = new ArrayList<TreeNode>();
        this.path = new ArrayList<String>();
    }

    /** @return the attribute name this node splits on */
    public String getName() {
        return this.name;
    }

    /** @param name the attribute name this node splits on */
    public void setName(String name) {
        this.name = name;
    }

    /** @return the split score stored on this node */
    public Double getValue() {
        return this.value;
    }

    /** @param value the split score to store on this node */
    public void setValue(Double value) {
        this.value = value;
    }

    /** @return the class label of this node */
    public String getClassify() {
        return this.classify;
    }

    /** @param classify the class label of this node */
    public void setClassify(String classify) {
        this.classify = classify;
    }

    /** @return the parent node, or null for the root */
    public TreeNode getParent() {
        return this.parent;
    }

    /** @param parent the parent node */
    public void setParent(TreeNode parent) {
        this.parent = parent;
    }

    /** @return the live (mutable) list of child nodes */
    public List<TreeNode> getChilds() {
        return this.childs;
    }

    /** @param childs the list to use as this node's children (stored, not copied) */
    public void setChilds(List<TreeNode> childs) {
        this.childs = childs;
    }

    /** @return the live (mutable) list of branch values leading to this node */
    public List<String> getPath() {
        return this.path;
    }

    /** @param path the list to use as this node's branch values (stored, not copied) */
    public void setPath(List<String> path) {
        this.path = path;
    }
}
3、Decision.java
package com.decision.my;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;
/**
 * <p>Class description: builds a decision tree from categorical training rows by repeatedly
 * splitting on the attribute with the highest information gain (ID3-style), then prints every
 * root-to-leaf path of the tree.</p>
 * <p>Data layout: each row holds the attribute values in the same order as the attribute list,
 * followed by the class label as its last element.</p>
 * @author Wang Haiyang
 * @date 2015-06-23 19:21:17
 */
public class Decision {

    /** Data file read when no path is supplied on the command line. */
    private static final String DEFAULT_DATA_PATH = "D:/data.txt";

    /**
     * Recursively builds a decision tree for the given rows.
     *
     * <p>The returned node is a leaf labelled with the class when all rows share one label, a
     * leaf labelled {@code "more"} when labels are mixed but no attributes remain, and an
     * unlabelled leaf when {@code D} is null or empty. Otherwise it is named after the attribute
     * with the highest information gain (stored as its value) and gets one child per value of
     * that attribute occurring in {@code D}.</p>
     *
     * @param D rows whose i-th element is the value of {@code attributeList.get(i)} and whose
     *          last element is the class label
     * @param attributeList names of the attributes still available for splitting; not modified
     * @return the root of the tree built for {@code D}
     */
    public static TreeNode createDecisionTree(List<ArrayList<String>> D, List<String> attributeList) {
        TreeNode root = new TreeNode();
        if (D == null || D.isEmpty()) {
            // Nothing to learn from; return an empty leaf instead of failing.
            return root;
        }
        String classfy = pure(D);
        if (classfy != null) {
            root.setClassify(classfy);
            return root;
        }
        if (attributeList == null || attributeList.isEmpty()) {
            root.setClassify("more");
            return root;
        }

        double entropyD = getGainD(D);
        // Start below every achievable gain so that some attribute is always chosen. Starting at
        // 0 leaves a node with neither children nor a class whenever every gain is 0 (e.g. for
        // XOR-like labels) or rounds to a tiny negative number.
        double max = Double.NEGATIVE_INFINITY;
        int ind = 0;
        Map<String, Integer> mapAttribute = new HashMap<String, Integer>();
        for (int i = 0; i < attributeList.size(); i++) {
            Map<String, Integer> map = countColumn(D, i);
            double gainA = getGainA(map, D, entropyD, i);
            if (gainA > max) {
                max = gainA;
                ind = i;
                mapAttribute = map;
            }
        }
        root.setName(attributeList.get(ind));
        root.setValue(max);

        // Copy rather than clone(): works for any List implementation and leaves the caller's list intact.
        List<String> newAttributeList = new ArrayList<String>(attributeList);
        newAttributeList.remove(ind);
        for (Entry<String, Integer> entry : mapAttribute.entrySet()) {
            List<ArrayList<String>> datas = getData(entry, D, ind);
            TreeNode child = createDecisionTree(datas, newAttributeList);
            child.setParent(root);
            // Record the branch value leading to this child; getAllPath reads it back.
            child.getPath().add(entry.getKey());
            root.getChilds().add(child);
        }
        return root;
    }

    /**
     * Selects the rows whose value in column {@code ind} equals the entry's key and returns
     * copies of them with that column removed.
     *
     * @param entry attribute value (key) to match; the count is ignored
     * @param D rows to filter; not modified
     * @param ind column of the attribute being split on
     * @return the matching rows without column {@code ind}
     */
    private static List<ArrayList<String>> getData(Entry<String, Integer> entry, List<ArrayList<String>> D, int ind) {
        String wanted = entry.getKey();
        List<ArrayList<String>> results = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> row : D) {
            if (wanted.equals(row.get(ind))) {
                ArrayList<String> reduced = new ArrayList<String>(row);
                reduced.remove(ind);
                results.add(reduced);
            }
        }
        return results;
    }

    /**
     * Computes the information gain of splitting {@code D} on column {@code i}.
     *
     * @param map number of rows per distinct value of column {@code i}
     * @param D the rows being split
     * @param gainD entropy of the class labels of {@code D} (see {@link #getGainD})
     * @param i column of the candidate attribute
     * @return {@code gainD} minus the weighted entropy of the labels within each value group
     */
    private static Double getGainA(Map<String, Integer> map, List<ArrayList<String>> D, Double gainD, int i) {
        double allCount = D.size();
        double info = 0D;
        for (Entry<String, Integer> entry : map.entrySet()) {
            double attributeCount = entry.getValue();
            Map<String, Integer> labelCounts = new HashMap<String, Integer>();
            for (ArrayList<String> row : D) {
                if (entry.getKey().equals(row.get(i))) {
                    increment(labelCounts, label(row));
                }
            }
            info += (attributeCount / allCount) * entropy(labelCounts, attributeCount);
        }
        return gainD - info;
    }

    /**
     * Computes the entropy (in bits) of the class labels of {@code D}.
     *
     * @param D the rows; an empty list yields 0
     * @return the label entropy of {@code D}
     */
    private static Double getGainD(List<ArrayList<String>> D) {
        Map<String, Integer> labelCounts = new HashMap<String, Integer>();
        for (ArrayList<String> row : D) {
            increment(labelCounts, label(row));
        }
        return entropy(labelCounts, D.size());
    }

    /** Counts the rows of {@code D} per distinct value found in {@code column}. */
    private static Map<String, Integer> countColumn(List<ArrayList<String>> D, int column) {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (ArrayList<String> row : D) {
            increment(counts, row.get(column));
        }
        return counts;
    }

    /** Returns the class label of a row, i.e. its last element. */
    private static String label(ArrayList<String> row) {
        return row.get(row.size() - 1);
    }

    /** Adds one to the count stored for {@code key}, starting at 1 when absent. */
    private static void increment(Map<String, Integer> counts, String key) {
        Integer current = counts.get(key);
        counts.put(key, current == null ? 1 : current + 1);
    }

    /** Returns the Shannon entropy in bits of the distribution whose counts sum to {@code total}. */
    private static double entropy(Map<String, Integer> counts, double total) {
        double result = 0D;
        for (Integer count : counts.values()) {
            double p = count / total;
            result -= p * Math.log(p) / Math.log(2);
        }
        return result;
    }

    /**
     * Checks whether all rows of {@code D} share the same class label.
     *
     * @param D the rows
     * @return that label if there is exactly one distinct label, otherwise null
     */
    private static String pure(List<ArrayList<String>> D) {
        Set<String> sets = new HashSet<String>();
        for (ArrayList<String> row : D) {
            sets.add(label(row));
        }
        return sets.size() == 1 ? sets.iterator().next() : null;
    }

    /**
     * Builds a tree from the data file and prints each root-to-leaf path.
     *
     * @param args optional; {@code args[0]} overrides the default data file path
     */
    public static void main(String[] args) {
        List<ArrayList<String>> data = (args != null && args.length > 0) ? configData(args[0]) : configData();
        TreeNode node = createDecisionTree(data, configAttribute());
        List<ArrayList<ResultNode>> nodes = getAllPath(node);
        for (int i = 0; i < nodes.size(); i++) {
            System.out.println("第" + (i + 1) + "条路径:");
            ArrayList<ResultNode> lists = nodes.get(i);
            for (int j = 0; j < lists.size() - 1; j++) {
                System.out.print("(" + lists.get(j).getName() + ":" + lists.get(j).getValue() + "),");
            }
            System.out.print(lists.get(lists.size() - 1).getClassfy());
            System.out.println();
        }
    }

    /**
     * Collects all leaf nodes of the given tree.
     *
     * @param root the tree's root
     * @return the leaves in depth-first order
     */
    private static List<TreeNode> getLeafs(TreeNode root) {
        List<TreeNode> results = new ArrayList<TreeNode>();
        traverseTree(root, results);
        return results;
    }

    /**
     * Recursively walks the tree and appends every node without children to {@code results}.
     *
     * @param node the subtree root
     * @param results receives the leaves
     */
    private static void traverseTree(TreeNode node, List<TreeNode> results) {
        List<TreeNode> childs = node.getChilds();
        if (childs == null || childs.isEmpty()) {
            results.add(node);
        } else {
            for (TreeNode child : childs) {
                traverseTree(child, results);
            }
        }
    }

    /**
     * Produces one path per leaf, ordered from the root down.
     *
     * <p>Each path lists (attribute name, branch value) steps for the ancestors, followed by a
     * final step that carries only the leaf's class label.</p>
     *
     * @param root the tree's root
     * @return all root-to-leaf paths
     */
    private static List<ArrayList<ResultNode>> getAllPath(TreeNode root) {
        List<ArrayList<ResultNode>> results = new ArrayList<ArrayList<ResultNode>>();
        for (TreeNode leaf : getLeafs(root)) {
            ArrayList<ResultNode> resultNodes = new ArrayList<ResultNode>();
            ResultNode last = new ResultNode();
            last.setClassfy(leaf.getClassify());
            resultNodes.add(last);
            // Walk upwards: each ancestor contributes its attribute name together with the
            // branch value that leads to the child we came from.
            TreeNode child = leaf;
            TreeNode parent = leaf.getParent();
            while (parent != null) {
                ResultNode step = new ResultNode();
                step.setName(parent.getName());
                step.setValue(child.getPath().get(0));
                resultNodes.add(step);
                child = parent;
                parent = parent.getParent();
            }
            Collections.reverse(resultNodes);
            results.add(resultNodes);
        }
        return results;
    }

    /** Returns the attribute names of the sample data set, in column order. */
    private static List<String> configAttribute() {
        List<String> results = new ArrayList<String>();
        results.add("age");
        results.add("income");
        results.add("student");
        results.add("credit_rating");
        return results;
    }

    /** Reads the rows from {@link #DEFAULT_DATA_PATH}; see {@link #configData(String)}. */
    private static List<ArrayList<String>> configData() {
        return configData(DEFAULT_DATA_PATH);
    }

    /**
     * Reads comma-separated rows from a text file in the platform default charset.
     *
     * <p>Blank lines are skipped. I/O errors are printed and the rows read so far are returned.</p>
     *
     * @param path the file to read
     * @return one list of cell values per non-blank line
     */
    private static List<ArrayList<String>> configData(String path) {
        List<ArrayList<String>> results = new ArrayList<ArrayList<String>>();
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path))));
            String line;
            while ((line = reader.readLine()) != null) {
                // A blank line would otherwise become a one-column row and break column indexing later.
                if (line.trim().length() == 0) {
                    continue;
                }
                String[] split = line.split(",");
                ArrayList<String> row = new ArrayList<String>(split.length);
                for (String cell : split) {
                    row.add(cell);
                }
                results.add(row);
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            // Always release the file handle, including on read errors.
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        return results;
    }
}
4、测试
测试数据:在 D 盘根目录下放置 data.txt(程序读取的路径为 D:/data.txt),每行一条样本,各列依次为 age、income、student、credit_rating,最后一列为类别标签,内容如下:
youth,high,no,fair,no
youth,high,no,excellent,no
middle_aged,high,no,fair,yes
senior,medium,no,fair,yes
senior,low,yes,fair,yes
senior,low,yes,excellent,no
middle_aged,low,yes,excellent,yes
youth,medium,no,fair,no
youth,low,yes,fair,yes
senior,medium,yes,fair,yes
youth,medium,yes,excellent,yes
middle_aged,medium,no,excellent,yes
middle_aged,high,yes,fair,yes
senior,medium,no,excellent,no