主要实现
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * ID3 decision-tree builder for a two-class data set read from a comma-separated text file.
 *
 * <p>The first line of the file names the columns; the last column is the decision and every
 * other column is an input attribute. Each following line is one data row. Only the first two
 * distinct decision labels found in the file are tracked.
 */
public class ID3 {
    /** Tree that receives every node created by {@link #makeTree}. */
    public static GTree<String> tree = new GTree<>();
    /** Names of the input columns (outlook, temperature, humidity, windy in the sample data). */
    public static String[] attribute;
    /** Name of the decision column (play in the sample data); a one-element array. */
    public String[] valuename;
    /** The two decision labels (YES / NO), in order of first appearance in the file. */
    public String[] value = new String[2];
    /** Data rows: input columns followed by the decision as the last element. */
    public static List<String[]> data;
    /** Distinct values of each input column, indexed like {@link #attribute}. */
    public static List<Set<String>> clomnSetList;

    // Instance initializer: every new ID3 reloads "test2.txt" and rebuilds the static tables.
    {
        data = getData("test2.txt");
        clomnSetList = getClomnValueSet(data.toArray(new String[0][]));
    }

    /**
     * Reads the data file, filling {@link #attribute}, {@link #valuename} and {@link #value}
     * as a side effect.
     *
     * <p>Blank lines are skipped. An empty file yields an empty list. An I/O failure prints a
     * message and returns whatever rows were read before the failure.
     *
     * @param path local path of a comma-separated text file
     * @return the data rows (header excluded); never {@code null}
     */
    public List<String[]> getData(String path) {
        List<String[]> rows = new ArrayList<>();
        // try-with-resources closes the reader even when reading fails part-way.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(path)))) {
            String header = reader.readLine();
            String[] firstLine = header == null ? new String[0] : header.split(",");
            if (firstLine.length == 0) {
                return rows; // empty file, or a header made only of separators
            }
            attribute = Arrays.copyOf(firstLine, firstLine.length - 1);
            valuename = new String[] { firstLine[firstLine.length - 1] };

            String nextline;
            while ((nextline = reader.readLine()) != null) {
                if (nextline.trim().isEmpty()) {
                    continue;
                }
                String[] d = nextline.split(",");
                if (d.length == 0) {
                    continue; // line made only of separators
                }
                rememberDecision(d[d.length - 1]);
                rows.add(d);
            }
        } catch (IOException e) {
            System.out.println("文件路径不正确");
        }
        return rows;
    }

    /** Records {@code decision} in {@link #value} if it is one of the first two distinct labels. */
    private void rememberDecision(String decision) {
        if (value[0] == null) {
            value[0] = decision;
        } else if (value[1] == null && !decision.trim().equals(value[0].trim())) {
            value[1] = decision;
        }
    }

    /** Tells whether the decision (last element) of {@code row} equals {@code label}, ignoring surrounding blanks. */
    private static boolean hasDecision(String[] row, String label) {
        // label is null when the data file contains fewer than two decision classes
        return label != null && row[row.length - 1].trim().equals(label.trim());
    }

    /** Counts the rows whose decision equals {@code label}. */
    private static int countDecision(List<String[]> rows, String label) {
        int hits = 0;
        for (String[] row : rows) {
            if (hasDecision(row, label)) {
                hits++;
            }
        }
        return hits;
    }

    /** Returns the more frequent of the two decision labels in {@code rows}; ties go to {@code value[0]}. */
    private String majorityLabel(List<String[]> rows) {
        return countDecision(rows, value[1]) > countDecision(rows, value[0]) ? value[1] : value[0];
    }

    /** One term {@code -p * log2(p)} of the entropy sum, with the usual convention 0 * log 0 = 0. */
    private static double entropyTerm(double p) {
        if (p <= 0 || p >= 1) {
            return 0;
        }
        return -p * (Math.log(p) / Math.log(2));
    }

    /** Two-class entropy from class counts; an empty set has entropy 0. */
    private static double entropyOf(int first, int second, double total) {
        if (total == 0) {
            return 0;
        }
        return entropyTerm(first / total) + entropyTerm(second / total);
    }

    /** Returns a copy of {@code source} with the element at {@code index} removed. */
    private static String[] withoutIndex(String[] source, int index) {
        String[] shorter = new String[source.length - 1];
        System.arraycopy(source, 0, shorter, 0, index);
        System.arraycopy(source, index + 1, shorter, index, shorter.length - index);
        return shorter;
    }

    /**
     * Conditional entropy of the decision given one input column: the entropy of each value's
     * subset weighted by the share of rows carrying that value.
     *
     * @param data   data rows
     * @param cloumn distinct values of the column
     * @param index  position of the column inside each row
     * @return the weighted entropy of the column
     */
    public double getGainClomn(List<String[]> data, Set<String> cloumn, int index) {
        double result = 0;
        for (String s : cloumn) {
            result += getGain(data, s, index) * getpro(data, s, index);
        }
        return result;
    }

    /**
     * Computes the subset entropy of every value of one column.
     *
     * @param data   data rows
     * @param cloumn distinct values of the column
     * @param index  position of the column inside each row
     * @return rows of {@code {value, entropy-as-text}}, sorted by descending entropy
     */
    public String[][] getGainClomnEntropy(List<String[]> data, Set<String> cloumn, int index) {
        String[][] result = new String[cloumn.size()][2];
        int i = 0;
        for (String s : cloumn) {
            result[i][0] = s;
            result[i][1] = String.valueOf(getGain(data, s, index));
            i++;
        }
        return sort(result);
    }

    /**
     * Sorts {@code {key, number-as-text}} pairs in place by descending number.
     *
     * @param a pairs whose second element parses as a double
     * @return the same array, sorted
     */
    public String[][] sort(String[][] a) {
        Arrays.sort(a, new Comparator<String[]>() {
            @Override
            public int compare(String[] o1, String[] o2) {
                return Double.compare(Double.parseDouble(o2[1]), Double.parseDouble(o1[1]));
            }
        });
        return a;
    }

    /**
     * Entropy of the decision among the rows whose column {@code index} equals {@code one}.
     *
     * <p>A pure subset (all rows share one decision) and an empty subset both have entropy 0.
     *
     * @param data  data rows
     * @param one   the column value selecting the subset
     * @param index position of the column inside each row
     * @return the entropy of the subset, between 0 and 1
     */
    public double getGain(List<String[]> data, String one, int index) {
        String wanted = one.trim();
        int total = 0;
        int first = 0;
        int second = 0;
        for (String[] d : data) {
            if (!d[index].trim().equals(wanted)) {
                continue;
            }
            total++;
            if (hasDecision(d, value[0])) {
                first++;
            }
            if (hasDecision(d, value[1])) {
                second++;
            }
        }
        return entropyOf(first, second, total);
    }

    /**
     * Share of rows whose column {@code index} equals {@code one}.
     *
     * @param data  data rows
     * @param one   the column value to look for
     * @param index position of the column inside each row
     * @return the fraction of matching rows; 0 when {@code data} is empty
     */
    public double getpro(List<String[]> data, String one, int index) {
        if (data.isEmpty()) {
            return 0;
        }
        String wanted = one.trim();
        int hits = 0;
        for (String[] d : data) {
            if (d[index].trim().equals(wanted)) {
                hits++;
            }
        }
        return hits / (double) data.size();
    }

    /**
     * Entropy of the decision over all rows.
     *
     * @param data data rows
     * @return the entropy of the whole set; 0 when the set is empty or pure
     */
    public double getGain(List<String[]> data) {
        return entropyOf(countDecision(data, value[0]), countDecision(data, value[1]), data.size());
    }

    /**
     * Information gain of splitting {@code data} on one column.
     *
     * @param data   data rows
     * @param cloumn distinct values of the column
     * @param index  position of the column inside each row
     * @return entropy of the whole set minus the conditional entropy of the column
     */
    public Double getGainCreat(List<String[]> data, Set<String> cloumn, int index) {
        return getGain(data) - getGainClomn(data, cloumn, index);
    }

    /**
     * Collects the distinct values of every input column. The transposed matrix is printed
     * to standard output on the way.
     *
     * @param a data rows; the last column (the decision) is left out of the result
     * @return one value set per input column, e.g. [{sunny, overcast, rainy}, {...}, ...]
     */
    public List<Set<String>> getClomnValueSet(String[][] a) {
        a = reverdraSort(a);
        print(a);
        List<Set<String>> list = new ArrayList<>();
        for (int i = 0; i < a.length - 1; i++) {
            // every row of the column counts, starting from the first one
            Set<String> set = new HashSet<>(Arrays.asList(a[i]));
            list.add(set);
        }
        return list;
    }

    /**
     * Transposes a matrix (despite the name, nothing is sorted).
     *
     * @param a matrix whose rows are at least as long as its first row
     * @return a new matrix with rows and columns swapped; empty when {@code a} is empty
     */
    public String[][] reverdraSort(String[][] a) {
        if (a.length == 0) {
            return new String[0][0];
        }
        int rowCount = a.length;
        int columnCount = a[0].length;
        String[][] transposed = new String[columnCount][rowCount];
        for (int column = 0; column < columnCount; column++) {
            for (int row = 0; row < rowCount; row++) {
                transposed[column][row] = a[row][column];
            }
        }
        return transposed;
    }

    /**
     * Prints a matrix to standard output, one row per line, followed by a blank line.
     *
     * @param arr the matrix to print
     */
    public static void print(String arr[][]) {
        for (String[] row : arr) {
            for (String cell : row) {
                System.out.print(cell + "、");
            }
            System.out.println();
        }
        System.out.println();
    }

    /**
     * Recursively builds the decision tree into {@link #tree}.
     *
     * <p>Each call picks the column with the largest information gain, adds a node for it and
     * one child per column value. A child whose rows all share one decision gets that decision
     * as a leaf. A value with no rows in {@code data1}, or a mixed subset with no column left
     * to split on, gets the majority decision as a leaf. Any other subset is split further
     * with the chosen column removed.
     *
     * @param data1        rows to split; each row is aligned with {@code attribute} plus the decision
     * @param root1        node to hang the new subtree under, or {@code null} for the tree root
     * @param clomnSetList distinct values of each remaining column
     * @param attribute    names of the remaining columns
     * @return {@code root1}, or the newly created root node when {@code root1} was {@code null}
     */
    public TreeNode<String> makeTree(List<String[]> data1, TreeNode<String> root1, List<Set<String>> clomnSetList,
            String[] attribute) {
        if (clomnSetList.isEmpty()) {
            // nothing left to split on: settle for the most frequent decision
            if (root1 != null && !data1.isEmpty()) {
                TreeNode<String> leaf = new TreeNode<>(majorityLabel(data1), null);
                tree.insert(root1, leaf);
            }
            return root1;
        }

        // choose the column with the largest information gain
        double max = 0D;
        int maxIndex = 0;
        for (int i = 0; i < clomnSetList.size(); i++) {
            double temp = getGainCreat(data1, clomnSetList.get(i), i);
            if (max < temp) {
                max = temp;
                maxIndex = i;
            }
        }

        TreeNode<String> n1 = new TreeNode<>(attribute[maxIndex], null);
        if (root1 == null) {
            root1 = n1;
            tree.insert(null, root1);
        } else {
            tree.insert(root1, n1);
        }

        // column names and value sets without the chosen column, shared by all sub-calls
        String[] newa = withoutIndex(attribute, maxIndex);
        List<Set<String>> clomnSetListnew = new ArrayList<>(clomnSetList);
        // remove by position: two columns may own equal value sets
        clomnSetListnew.remove(maxIndex);

        // one child per value of the chosen column, highest entropy first
        String[][] device = getGainClomnEntropy(data1, clomnSetList.get(maxIndex), maxIndex);
        for (String[] entry : device) {
            String v = entry[0];
            TreeNode<String> n = new TreeNode<>(v, null);
            tree.insert(n1, n);

            // rows carrying this value, with the chosen column dropped
            List<String[]> ndata1 = new ArrayList<>();
            for (String[] s : data1) {
                if (s[maxIndex].trim().equals(v.trim())) {
                    ndata1.add(withoutIndex(s, maxIndex));
                }
            }

            String leafLabel = null;
            if (ndata1.isEmpty()) {
                leafLabel = majorityLabel(data1); // value never seen in this subset
            } else if (countDecision(ndata1, value[0]) == ndata1.size()) {
                leafLabel = value[0];
            } else if (countDecision(ndata1, value[1]) == ndata1.size()) {
                leafLabel = value[1];
            }

            if (leafLabel != null) {
                TreeNode<String> n2 = new TreeNode<>(leafLabel, null);
                tree.insert(n, n2);
            } else {
                makeTree(ndata1, n, clomnSetListnew, newa);
            }
        }
        return root1;
    }

    public static void main(String[] args) {
        ID3 id3 = new ID3();
        TreeNode<String> treenode = id3.makeTree(data, null, clomnSetList, attribute);
        tree.Travelsal(treenode, 1);
    }
}
数据结构支持
import java.util.ArrayList;
import java.util.List;
//通用树的节点
/**
 * Node of a general (n-ary) tree: a payload plus an ordered list of child nodes.
 *
 * @param <T> type parameter shared with the child nodes
 */
public class TreeNode<T> {
    /** Payload carried by this node. */
    private Object value;
    /** Child nodes, in insertion order. */
    private List<TreeNode<T>> childlist;

    /** Creates a node with no payload and an empty child list. */
    public TreeNode() {
        this(null, null);
    }

    /**
     * Creates a node.
     *
     * @param value     payload of the node
     * @param childList children to adopt as-is (not copied); {@code null} gives a fresh empty list
     */
    public TreeNode(Object value, List<TreeNode<T>> childList) {
        this.value = value;
        this.childlist = (childList == null) ? new ArrayList<TreeNode<T>>() : childList;
    }

    /** @return the payload of this node */
    public Object getValue() {
        return this.value;
    }

    /** @param value the new payload */
    public void setValue(Object value) {
        this.value = value;
    }

    /** @return the live child list (not a copy) */
    public List<TreeNode<T>> getChildlist() {
        return this.childlist;
    }

    /** @param childlist the list to use as child list from now on; stored as-is */
    public void setChildlist(List<TreeNode<T>> childlist) {
        this.childlist = childlist;
    }
}
/**
 * General (n-ary) tree made of {@link TreeNode}s, with insertion, identity lookup and two
 * console printers.
 *
 * @param <T> type parameter of the nodes
 */
public class GTree<T> {
    /** Root node; {@code null} while the tree is empty. */
    public TreeNode<T> root = null;

    /**
     * Inserts a node.
     *
     * <p>While the tree is empty, {@code node} becomes the root and {@code parent} is ignored.
     * Otherwise {@code node} is appended to the child list of {@code parent}, provided
     * {@code parent} is a node of this tree.
     *
     * @param parent node to attach to
     * @param node   node to insert
     * @return {@code true} if the node was inserted, {@code false} if {@code parent} is not in the tree
     */
    public boolean insert(TreeNode<T> parent, TreeNode<T> node) {
        if (root == null) {
            root = node;
            return true;
        }
        if (!findOne(root, parent)) {
            return false;
        }
        // TODO: decide whether appending straight to the node's own child list is what we want
        return parent.getChildlist().add(node);
    }

    /**
     * Searches a subtree for a node, comparing by identity ({@code ==}).
     *
     * @param tRoot root of the subtree to search
     * @param one   node to look for
     * @return {@code true} if {@code one} is {@code tRoot} or one of its descendants
     */
    public boolean findOne(TreeNode<T> tRoot, TreeNode<T> one) {
        // an empty subtree cannot contain the node
        if (tRoot == null) {
            return false;
        }
        if (tRoot == one) {
            return true;
        }
        if (tRoot.getChildlist() == null) {
            return false;
        }
        for (TreeNode<T> child : tRoot.getChildlist()) {
            if (findOne(child, one)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Prints a subtree to standard output, one branch per line, children indented by level.
     *
     * @param root root of the subtree; nothing is printed when {@code null}
     * @param l    level of {@code root}; the root's own value is printed only when this is 1
     */
    public void Travelsal(TreeNode<String> root, int l) {
        if (root == null) {
            return;
        }
        if (l == 1) {
            // String.valueOf prints "null" for a missing value, like the child printf below
            System.out.printf("|--%-10s--", String.valueOf(root.getValue()));
        }
        if (root.getChildlist() == null || root.getChildlist().size() == 0) {
            return;
        }
        // padding printed after each child line: 10 columns per level plus 5 per level
        int base = l * 10;
        int padWidth = base + (base / 10) * 5;
        String padFormat = "%-" + padWidth + "s";
        int childLevel = l + 1;
        int length = root.getChildlist().size();
        for (int i = 0; i < length; i++) {
            TreeNode<String> node = root.getChildlist().get(i);
            System.out.printf("|--%-10s--", node.getValue());
            Travelsal(node, childLevel);
            System.out.print("\n");
            System.out.printf(padFormat, " ");
        }
    }

    /**
     * Prints a subtree to standard output, one node per line, prefixed by {@code |} and
     * {@code 3 * l + 1} dashes.
     *
     * @param root root of the subtree; nothing is printed when {@code null}
     * @param l    level of {@code root}, which sets the number of dashes
     */
    public void Travelsal1(TreeNode<String> root, int l) {
        if (root == null) {
            return;
        }
        System.out.print("|");
        int length = l * 3;
        for (int i = 0; i < length + 1; i++) {
            System.out.print("-");
        }
        System.out.println(root.getValue());
        if (root.getChildlist() == null) {
            return;
        }
        int clength = root.getChildlist().size();
        for (int j = 0; j < clength; j++) {
            Travelsal1(root.getChildlist().get(j), l + 1);
        }
    }
}