决策树——ID3算法的java实现

所谓决策树就是用树来帮助我们做决策,从树的根节点开始一级一级的访问节点,直到叶子节点,也就完成了决策的过程。

决策树算法是描述用已知的样本来构建决策树的过程,这边用比较经典的“气候—玩”的例子来说明,

描述气候有很多指标(天色、温度、湿度、风速),想得到的决策结论是是否能玩(yes or no),

这边有一组已知的样本,存于weather.nominal.arff文件中,如下:

@relation weather.symbolic
 
@attribute outlook {sunny, overcast, rainy}
@attribute temperature {hot, mild, cool}
@attribute humidity {high, normal}
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}
 
@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
arff文件用的比较多,很多地方有介绍,主要就是属性和数据,

@attribute开头的是属性,有属性名、可选值;

@data之后的每一行都是数据,逗号分隔每一个属性。


下面是一个ID3算法的java实现

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;

public class ID3 {
	private ArrayList<String> attribute = new ArrayList<String>(); // 存储属性的名称
	private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // 存储每个属性的取值
	private ArrayList<String[]> data = new ArrayList<String[]>();; // 原始数据
	int decatt; // 决策变量在属性集中的索引
	public static final String patternString = "@attribute(.*)[{](.*?)[}]";

	Document xmldoc;
	Element root;

	public ID3() {
		xmldoc = DocumentHelper.createDocument();
		root = xmldoc.addElement("root");
		root.addElement("DecisionTree").addAttribute("value", "null");
	}

	// 读取arff文件,给attribute、attributevalue、data赋值
	public void readARFF(File file) {
		try {
			FileReader fr = new FileReader(file);
			BufferedReader br = new BufferedReader(fr);
			String line;
			Pattern pattern = Pattern.compile(patternString);
			while ((line = br.readLine()) != null) {
				Matcher matcher = pattern.matcher(line);
				if (matcher.find()) { // 读@attribute
					attribute.add(matcher.group(1).trim());
					String[] values = matcher.group(2).split(",");
					ArrayList<String> al = new ArrayList<String>(values.length);
					for (String value : values) {
						al.add(value.trim());
					}
					attributevalue.add(al);
				} else if (line.startsWith("@data")) { // 读@data
					while ((line = br.readLine()) != null) {
						if (line == "")
							continue;
						String[] row = line.split(",");
						data.add(row);
					}
				} else {
					continue;
				}
			}
			br.close();
		} catch (IOException e1) {
			e1.printStackTrace();
		}
	}

	// 设置决策变量
	public void setDec(int n) {
		if (n < 0 || n >= attribute.size()) {
			System.err.println("决策变量指定错误。");
			System.exit(2);
		}
		decatt = n;
	}

	public void setDec(String name) {
		int n = attribute.indexOf(name);
		setDec(n);
	}

	// 给一个样本(数组中是各种情况的计数),计算它的熵
	public double getEntropy(int[] arr) {
		double entropy = 0.0;
		int sum = 0;
		for (int i = 0; i < arr.length; i++) {
			entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE)
					/ Math.log(2);
			sum += arr[i];
		}
		entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
		entropy /= sum;
		return entropy;
	}

	// 给一个样本数组及样本的算术和,计算它的熵
	public double getEntropy(int[] arr, int sum) {
		double entropy = 0.0;
		for (int i = 0; i < arr.length; i++) {
			entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE)
					/ Math.log(2);
		}
		entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
		entropy /= sum;
		return entropy;
	}

	public boolean infoPure(ArrayList<Integer> subset) {
		String value = data.get(subset.get(0))[decatt];
		for (int i = 1; i < subset.size(); i++) {
			String next = data.get(subset.get(i))[decatt];
			// equals表示对象内容相同,==表示两个对象指向的是同一片内存
			if (!value.equals(next))
				return false;
		}
		return true;
	}

	// 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵
	public double calNodeEntropy(ArrayList<Integer> subset, int index) {
		int sum = subset.size();
		double entropy = 0.0;
		int[][] info = new int[attributevalue.get(index).size()][];
		for (int i = 0; i < info.length; i++)
			info[i] = new int[attributevalue.get(decatt).size()];
		int[] count = new int[attributevalue.get(index).size()];
		for (int i = 0; i < sum; i++) {
			int n = subset.get(i);
			String nodevalue = data.get(n)[index];
			int nodeind = attributevalue.get(index).indexOf(nodevalue);
			count[nodeind]++;
			String decvalue = data.get(n)[decatt];
			int decind = attributevalue.get(decatt).indexOf(decvalue);
			info[nodeind][decind]++;
		}
		for (int i = 0; i < info.length; i++) {
			entropy += getEntropy(info[i]) * count[i] / sum;
		}
		return entropy;
	}

	// 构建决策树
	public void buildDT(String name, String value, ArrayList<Integer> subset,
			LinkedList<Integer> selatt) {
		Element ele = null;
		@SuppressWarnings("unchecked")
		List<Element> list = root.selectNodes("//" + name);
		Iterator<Element> iter = list.iterator();
		while (iter.hasNext()) {
			ele = iter.next();
			if (ele.attributeValue("value").equals(value))
				break;
		}
		if (infoPure(subset)) {
			ele.setText(data.get(subset.get(0))[decatt]);
			return;
		}
		int minIndex = -1;
		double minEntropy = Double.MAX_VALUE;
		for (int i = 0; i < selatt.size(); i++) {
			if (i == decatt)
				continue;
			double entropy = calNodeEntropy(subset, selatt.get(i));
			if (entropy < minEntropy) {
				minIndex = selatt.get(i);
				minEntropy = entropy;
			}
		}
		String nodeName = attribute.get(minIndex);
		selatt.remove(new Integer(minIndex));
		ArrayList<String> attvalues = attributevalue.get(minIndex);
		for (String val : attvalues) {
			ele.addElement(nodeName).addAttribute("value", val);
			ArrayList<Integer> al = new ArrayList<Integer>();
			for (int i = 0; i < subset.size(); i++) {
				if (data.get(subset.get(i))[minIndex].equals(val)) {
					al.add(subset.get(i));
				}
			}
			buildDT(nodeName, val, al, selatt);
		}
	}

	// 把xml写入文件
	public void writeXML(String filename) {
		try {
			File file = new File(filename);
			if (!file.exists())
				file.createNewFile();
			FileWriter fw = new FileWriter(file);
			OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式
			XMLWriter output = new XMLWriter(fw, format);
			output.write(xmldoc);
			output.close();
		} catch (IOException e) {
			System.out.println(e.getMessage());
		}
	}

	public static void main(String[] args) {
		ID3 inst = new ID3();
		inst.readARFF(new File(System.getProperty("user.dir") + "\\resource\\weather.nominal.arff"));
		inst.setDec("play");
		LinkedList<Integer> ll = new LinkedList<Integer>();
		for (int i = 0; i < inst.attribute.size(); i++) {
			if (i != inst.decatt)
				ll.add(i);
		}
		ArrayList<Integer> al = new ArrayList<Integer>();
		for (int i = 0; i < inst.data.size(); i++) {
			al.add(i);
		}
		inst.buildDT("DecisionTree", "null", al, ll);
		inst.writeXML(System.getProperty("user.dir") + "\\resource\\dt.xml");
		return;
	}
}

读样本文件weather.nominal.arff,然后调用buildDT构建决策树,将决策树描述成xml输出到dt.xml文件中,

结果如下:

<?xml version="1.0" encoding="UTF-8"?>

<root>
  <DecisionTree value="null">
    <outlook value="sunny">
      <humidity value="high">no</humidity>
      <humidity value="normal">yes</humidity>
    </outlook>
    <outlook value="overcast">yes</outlook>
    <outlook value="rainy">
      <windy value="TRUE">no</windy>
      <windy value="FALSE">yes</windy>
    </outlook>
  </DecisionTree>
</root>


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值