用Java实现的ID3算法

主要实现 



import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class ID3 {
	public static GTree<String> tree = new GTree<>();// 一颗通用树
	public static String[] attribute;// 自变因素类别列表(outlook,temperature,humidity,windy)
	public String[] valuename;// 因变因素列表(play)
	public String[] value = new String[2];// 决策值(YES,NO)因变因素
	public static List<String[]> data;// 数据
	public static List<Set<String>> clomnSetList;// 自变因素

	// 初始化数据

	{
		data = getData("test2.txt");
		String[][] a = new String[data.size()][];
		int ii = 0;
		for (String[] s : data) {
			a[ii++] = s;
		}
		// 初始化自变因素列表
		clomnSetList = getClomnValueSet(a);
	}

	// 获取数据
	/**
	 * 获得数据
	 * 
	 * @param path
	 *            本地文件路径 (仅支持文本文件)
	 * @return List<String[]>
	 */
	public List<String[]> getData(String path) {
		List<String[]> data = new ArrayList<>();
		File f = new File(path);
		FileReader fr;
		try {
			fr = new FileReader(f);
			BufferedReader bfr = new BufferedReader(fr);
			String[] firstLine = bfr.readLine().split(",");
			attribute = new String[firstLine.length - 1];
			valuename = new String[1];
			// 初始化自变因素列表
			for (int i = 0; i < firstLine.length - 1; i++) {
				attribute[i] = firstLine[i];
			}
			// 初始因变因素
			valuename[0] = firstLine[firstLine.length - 1];
			// 初始化数据
			String nextline;
			while ((nextline = bfr.readLine()) != null) {
				String[] d = nextline.split(",");
				if (value[0] != d[d.length - 1] && value[1] != d[d.length - 1]) {
					if (value[0] == null) {
						value[0] = d[d.length - 1];
					} else if (value[1] == null && (!d[d.length - 1].equals(value[0]))) {
						value[1] = d[d.length - 1];
					}
				}
				data.add(d);
			}
		} catch (Exception e) {

			System.out.println("文件路径不正确");
		}

		return data;
	}

	// 获取自变因素一列数据项信息熵
	/**
	 * 
	 * @param data
	 *            数据源
	 * @param cloumn
	 *            一列自变因素
	 * @param index
	 *            列指针
	 * @return 一类自变因素信息熵
	 */
	public double getGainClomn(List<String[]> data, Set<String> cloumn, int index) {
		double result = 0;
		for (String s : cloumn) {
			double entropy = getGain(data, s, index);
			double probility = getpro(data, s, index);
			result = result + entropy * probility;
		}
		return result;
	}

	// 获取一列自变因素各因素熵值,并进行排序
	/**
	 * 
	 * @param data
	 *            数据源
	 * @param cloumn
	 *            一列自变因素
	 * @param index
	 *            列指针
	 * @return key:value形式的二维矩阵
	 */
	public String[][] getGainClomnEntropy(List<String[]> data, Set<String> cloumn, int index) {
		String[][] result = new String[cloumn.size()][2];
		int i = 0;
		for (String s : cloumn) {
			int j = 0;
			double value = getGain(data, s, index);
			result[i][j++] = s;
			result[i++][j] = "" + value;
		}

		// 排序

		return sort(result);
	}

	// key:value形式二维数组排序方法
	public String[][] sort(String[][] a) {
		Arrays.sort(a, new Comparator<String[]>() {

			@Override
			public int compare(String[] o1, String[] o2) {

				return Double.compare(Double.parseDouble(o2[1]), Double.parseDouble(o1[1]));
			}
		});
		return a;

	}

	// 获取指定数据信息熵
	/**
	 * 
	 * @param data
	 *            数据源
	 * @param one
	 *            要求取熵的一个自变因素
	 * @param index
	 *            所在列指针
	 * @return 熵
	 */
	public double getGain(List<String[]> data, String one, int index) {
		double count = 0;
		int count1 = 0;
		int count2 = 0;
		for (String[] d : data) {

			if (d[index].trim().equals(one.trim())) {
				count++;
				if (d[d.length - 1].trim().equals(value[0].trim())) {
					count1++;
				}
				if (d[d.length - 1].trim().equals(value[1].trim())) {
					count2++;
				}
			}

		}
		double probability1 = Double.parseDouble("" + count1) / count;// 决策1概率
		double probability2 = Double.parseDouble("" + count2) / count;// 决策2概率
		if (probability1 == 0) {
			return 0;
		}
		if (probability1 == 1) {
			return 1;
		}
		double result = -probability1 * (Math.log(probability1) / Math.log(2))
				- probability2 * (Math.log(probability2) / Math.log(2));
		return result;
	}

	// 获取指定自变因素概率
	/**
	 * 获取指定自变因素概率
	 * 
	 * @param data
	 *            数据源
	 * @param one
	 *            自变因素
	 * @param index
	 *            所在列指针
	 * @return 概率
	 */
	public double getpro(List<String[]> data, String one, int index) {
		double count = data.size();
		int count1 = 0;
		for (String[] d : data) {

			if (d[index].trim().equals(one.trim())) {
				count1++;

			}

		}
		return Double.parseDouble("" + count1) / count;

	}

	// 获取当前文件系统信息熵
	/**
	 * 获取当前文件系统信息熵
	 * 
	 * @param data
	 *            数据源(原始数据)
	 * @return 系统信息熵
	 */
	public double getGain(List<String[]> data) {
		double count = data.size();
		int count1 = 0;
		int count2 = 0;
		for (String[] d : data) {

			if (d[d.length - 1].trim().equals(value[0].trim())) {
				count1++;
			}
			if (d[d.length - 1].trim().equals(value[1].trim())) {
				count2++;
			}
		}
		double probability1 = Double.parseDouble("" + count1) / count;
		double probability2 = Double.parseDouble("" + count2) / count;
		double result = -probability1 * (Math.log(probability1) / Math.log(2))
				- probability2 * (Math.log(probability2) / Math.log(2));
		return result;
	}

	// 获取信息增益
	/**
	 * 
	 * @param data
	 *            数据源
	 * @param cloumn
	 *            一列自变因素
	 * @param index
	 *            列指针
	 * @return 信息增益
	 */

	public Double getGainCreat(List<String[]> data, Set<String> cloumn, int index) {
		return getGain(data) - getGainClomn(data, cloumn, index);
	}

	// 获取当前列的数据有哪些
	/**
	 * 初始化一列自变因素列表
	 * 
	 * @param a数据源
	 * @return [{sunny, overcast, rainy},{.....},...]
	 */
	public List<Set<String>> getClomnValueSet(String[][] a) {
		a = reverdraSort(a);
		print(a);
		List<Set<String>> list = new ArrayList<>();
		for (int i = 0; i < a.length - 1; i++) {
			Set<String> set = new HashSet<>();
			for (int j = i; j < a[i].length; j++) {
				set.add(a[i][j]);
			}
			list.add(set);
		}
		return list;

	}
	// 二维数组列行倒置排序法

	public String[][] reverdraSort(String[][] a) {
		int l1 = a.length;
		int l2 = a[0].length;
		String[][] a1 = new String[l2][l1];
		for (int i = 0; i < l2; i++) {

			for (int j = 0; j < l1; j++) {

				a1[i][j] = a[j][i];
			}

		}
		return a1;
	}
	// 打印二维数组方法

	public static void print(String arr[][]) {

		for (int i = 0; i < arr.length; i++) {

			for (int j = 0; j < arr[i].length; j++) {

				System.out.print(arr[i][j] + "、");

			}

			System.out.println();

		}

		System.out.println();

	}

	// 构建树
	/**
	 * 递归构建决策树
	 * 
	 * @param data1
	 *            数据源
	 * @param root1
	 *            当前根节点
	 * @param clomnSetList
	 *            自变因素矩阵
	 * @param attribute
	 *            自变因素类别列表
	 * @return
	 */
	public TreeNode<String> makeTree(List<String[]> data1, TreeNode<String> root1, List<Set<String>> clomnSetList,
			String[] attribute) {

		// 找信息熵最大的自变因素
		if (clomnSetList.size() > 1) {
			// System.out.println(root);
			Double max = 0D;
			int maxIndex = 0;// 信息熵最大因素下标
			for (int i = 0; i < clomnSetList.size(); i++) {
				double temp = getGainCreat(data1, clomnSetList.get(i), i);
				// if(temp==0) {
				// return;
				// }
				if (max < temp) {
					max = temp;
					maxIndex = i;
				}
			}
			TreeNode<String> n1 = null;
			if (root1 == null) {
				n1 = new TreeNode<>(attribute[maxIndex], null);
				root1 = n1;
				tree.insert(null, root1);
			} else {
				n1 = new TreeNode<>(attribute[maxIndex], null);

				tree.insert(root1, n1);
			}

			// 获取此自变因素的决策数组(熵数组)
			String[][] device = getGainClomnEntropy(data1, clomnSetList.get(maxIndex), maxIndex);
			for (int i = 0; i < device.length; i++) {
				TreeNode<String> n = new TreeNode<>(device[i][0], null);
				tree.insert(n1, n);
				if (Double.parseDouble(device[i][1]) == 0) {
					TreeNode<String> n2 = new TreeNode<>(value[1], null);
					tree.insert(n, n2);
				} else if (Double.parseDouble(device[i][1]) == 1) {
					TreeNode<String> n2 = new TreeNode<>(value[0], null);
					tree.insert(n, n2);
				}

				else {
					// 重建数据
					String v = device[i][0];

					List<String[]> ndata1 = new ArrayList<>();
					for (String[] s : data1) {
						if (s[maxIndex].trim().equals(v.trim())) {
							String[] nn = new String[s.length - 1];
							for (int j = 0; j < maxIndex; j++) {
								nn[j] = s[j];
							}
							for (int j = maxIndex; j < nn.length; j++) {
								nn[j] = s[j + 1];
							}

							ndata1.add(nn);
						}
					}

					String[] newa = new String[attribute.length - 1];
					for (int k = 0; k < maxIndex; k++) {
						newa[k] = attribute[k];
					}
					for (int k = maxIndex; k < newa.length; k++) {
						newa[k] = attribute[k + 1];
					}

					List<Set<String>> clomnSetListnew = new ArrayList<>(clomnSetList);
					clomnSetListnew.remove(clomnSetListnew.get(maxIndex));

					makeTree(ndata1, n, clomnSetListnew, newa);

				}

			}

		}
		return root1;

	}

	public static void main(String[] args) {
		ID3 id3 = new ID3();
		TreeNode<String> treenode = id3.makeTree(data, null, clomnSetList, attribute);

		tree.Travelsal(treenode, 1);
	}

}

数据结构支持



import java.util.ArrayList;
import java.util.List;

//通用树的节点
public class TreeNode<T>{
	private Object value;//数据区
	private List<TreeNode<T>> childlist;//孩子节点指针集合
	public TreeNode(){	
		value = null;
		childlist = new ArrayList<>();
	}
	
	public TreeNode(Object value,List<TreeNode<T>> childList) {
		this.value = value;
		if(childList!=null) {
			this.childlist = childList;
		}else {
			this.childlist=new ArrayList<>();
		}
		
	}

	public Object getValue() {
		return value;
	}

	public void setValue(Object value) {
		this.value = value;
	}

	public List<TreeNode<T>> getChildlist() {
		return childlist;
	}

	public void setChildlist(List<TreeNode<T>> childlist) {
		this.childlist = childlist;
	}
	
	
}


public class GTree<T> {
	// 根节点
	public TreeNode<T> root = null;

	// 插入
	public boolean insert(TreeNode<T> parent, TreeNode<T> node) {
		if (root == null) {
			root = node;
			return true;
		} else {
			if (findOne(root, parent)) {

				// 留待考虑
				// TODO 这里会不会直接修改节点的list,待考虑
				return parent.getChildlist().add(node);
			}
		}
		return false;
	}

	/**
	 * 
	 * @param tRoot要参照的根节点
	 * @param one要查找的节点
	 * @return 是否存在这个节点
	 */
	public boolean findOne(TreeNode<T> tRoot, TreeNode<T> one) {
		boolean b = false;
		// 参照根结点为空,则该节点一定不存在
		if (tRoot == null) {
			return false;
		}
		//
		if (tRoot == one) {
			return true;
		}

		if (tRoot.getChildlist() != null) {
			int length = tRoot.getChildlist().size();
			for (int i = 0; i < length; i++) {
				TreeNode<T> node = tRoot.getChildlist().get(i);
				if (node == one) {
					return true;
				} else {
					if (node.getChildlist().size() != 0) {
						b = b || findOne(node, one);
					}
				}
			}
		} else {
			return false;
		}

		return b;
	}

	// 遍历
	/**
	 * 
	 * @param root
	 *            根节点
	 * @param l
	 *            层数
	 */
	public void Travelsal(TreeNode<String> root, int l) {
		int temp = l * 10;

		if (root != null) {
			if (l == 1) {
				System.out.printf("|--%-10s--", root.getValue().toString());
			}

			if (root.getChildlist() != null && root.getChildlist().size() != 0) {
				l++;
				int length = root.getChildlist().size();
				for (int i = 0; i < length; i++) {
					TreeNode<String> node = root.getChildlist().get(i);
					System.out.printf("|--%-10s--", node.getValue());

					Travelsal(node, l);

					System.out.print("\n");
					int temp1 = temp;
					temp = temp + (temp / 10) * 5;
					System.out.printf("%-" + temp + "s", " ");
					temp = temp1;

				}
			}
		}
	}

	public void Travelsal1(TreeNode<String> root, int l) {

		System.out.print("|");
		int length = l * 3;
		for (int i = 0; i < length + 1; i++) {
			System.out.print("-");
		}
		System.out.println(root.getValue());
		int clength = root.getChildlist().size();
		for (int j = 0; j < clength; j++) {
			Travelsal1(root.getChildlist().get(j), l + 1);
		}
	}

}

 

 

 

 

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

时长河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值