归纳决策树ID3(Java实现)

42 篇文章 0 订阅
42 篇文章 1 订阅

先上问题吧,我们统计了14天的气象数据(指标包括outlook,temperature,humidity,windy),并已知这些天气是否打球(play)。如果给出新一天的气象指标数据:sunny,cool,high,TRUE,判断一下会不会去打球。

table 1

outlooktemperaturehumiditywindyplay
sunnyhothighFALSEno
sunnyhothighTRUEno
overcasthothighFALSEyes
rainymildhighFALSEyes
rainycoolnormalFALSEyes
rainycoolnormalTRUEno
overcastcoolnormalTRUEyes
sunnymildhighFALSEno
sunnycoolnormalFALSEyes
rainymildnormalFALSEyes
sunnymildnormalTRUEyes
overcastmildhighTRUEyes
overcasthotnormalFALSEyes
rainymildhighTRUEno

这个问题当然可以用朴素贝叶斯法求解,分别计算在给定天气条件下打球和不打球的概率,选概率大者作为推测结果。

现在我们使用ID3归纳决策树的方法来求解该问题。

预备知识:信息熵

熵是无序性(或不确定性)的度量指标。假如事件A的全概率划分是(A1,A2,...,An),每部分发生的概率是(p1,p2,...,pn),那信息熵定义为:

通常以2为底数,所以信息熵的单位是bit。

补充两个对数去处公式:

ID3算法

构造树的基本想法是随着树深度的增加,节点的熵迅速地降低。熵降低的速度越快越好,这样我们有望得到一棵高度最矮的决策树。

在没有给定任何天气信息时,根据历史数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为:

属性有4个:outlook,temperature,humidity,windy。我们首先要决定哪个属性作树的根节点。

对每项指标分别统计:在不同的取值下打球和不打球的次数。

table 2

outlooktemperaturehumiditywindyplay
 yesno yesno yesno yesnoyesno
sunny23hot22high34FALSE6295
overcast40mild42normal61TRUR33  
rainy32cool31        

下面我们计算当已知变量outlook的值时,信息熵为多少。

outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971

outlook=overcast时,entropy=0

outlook=rainy时,entropy=0.971

而根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693

这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)为0.940-0.693=0.247

同样可以计算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。

gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。

接下来要确定N1取temperature、humidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。

依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。

Java实现

最终的决策树保存在了XML中,使用了Dom4J,注意如果要让Dom4J支持按XPath选择节点,还得引入包jaxen.jar。程序代码要求输入文件满足ARFF格式,并且属性都是标称变量。

实验用的数据文件:

@relation weather.symbolic

@attribute outlook {sunny, overcast, rainy}
@attribute temperature {hot, mild, cool}
@attribute humidity {high, normal}
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}

@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no

程序代码:

package schoolarship;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;

public class ID3 {
	//存储属性的名称,这里判别变量和决策变量一律称为“属性”
	private ArrayList<String> attribute = new ArrayList<String>();
	//存储每个属性(都是离散变量)的取值集合
	private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>();
	//存储所有的训练数据,这是一个二维数组
	private ArrayList<String[]> data = new ArrayList<String[]>();
	//决策变量的属性列表中的索引号
	int decatt;
	//用于匹配ARFF文件中的@attribute行
	public static final String patternString = "@attribute(.*)[{](.*?)[}]";

	//使用Dom4j读写XML文件
	Document xmldoc;
	Element root;

	//构造函数中初始化Dom元素
	public ID3() {
		xmldoc = DocumentHelper.createDocument();
		root = xmldoc.addElement("root");
		root.addElement("DecisionTree").addAttribute("value", "null");
	}

	public static void main(String[] args) {
		ID3 inst = new ID3();
		//读入训练文件
		inst.readARFF(new File("d:\\weather.arff"));
		//设置决策变量的名称
		inst.setDec("play");
		//将所有属性(决策变量除外)的索引号存入ll
		LinkedList<Integer> ll = new LinkedList<Integer>();
		for (int i = 0; i < inst.attribute.size(); i++) {
			if (i != inst.decatt)
				ll.add(i);
		}
		//将全部训练数据的序号存入al
		ArrayList<Integer> al = new ArrayList<Integer>();
		for (int i = 0; i < inst.data.size(); i++) {
			al.add(i);
		}
		//递归构建决策树
		inst.buildDT(inst.root, al, ll);
		//将决策树写入XML文件
		inst.writeXML("d:\\dt.xml");
	}

	//读取输入文件,为全局变量attribute、attributevalue和data赋值
	public void readARFF(File file) {
		try {
			FileInputStream fis = new FileInputStream(file);
			InputStreamReader isr = new InputStreamReader(fis,
					initBookEncode(fis));
			BufferedReader br = new BufferedReader(isr);
			String line;
			Pattern pattern = Pattern.compile(patternString);
			while ((line = br.readLine()) != null) {
				Matcher matcher = pattern.matcher(line);
				if (matcher.find()) {
					attribute.add(matcher.group(1).trim());
					String[] values = matcher.group(2).split(",");
					ArrayList<String> al = new ArrayList<String>(values.length);
					for (String value : values) {
						al.add(value.trim());
					}
					attributevalue.add(al);
				} else if (line.startsWith("@data")) {
					while ((line = br.readLine()) != null) {
						if (line.equals(""))
							continue;
						String[] row = line.split(",");
						data.add(row);
					}
				} else {
					continue;
				}
			}
			br.close();
		} catch (IOException e1) {
			e1.printStackTrace();
		}
	}

	//将参数n赋给全局变量decatt
	public void setDec(int n) {
		if (n < 0 || n >= attribute.size()) {
			System.err.println("给定的决策变量名称有误");
			System.exit(2);
		}
		decatt = n;
	}

	//根据属性的名称设置全局变量decatt
	public void setDec(String name) {
		int n = attribute.indexOf(name);
		setDec(n);
	}

	//计算信息熵。arr中存储各种情况的频数
	public double getEntropy(int[] arr) {
		int sum = 0;
		for (int i = 0; i < arr.length; i++) {
			sum += arr[i];
		}
		return getEntropy(arr, sum);
	}

	//计算信息熵。arr中存储各种情况的频数,sum给出频数的总和
	public double getEntropy(int[] arr, int sum) {
		if (sum == 0)
			return 0;
		double entropy = 0.0;
		for (int i = 0; i < arr.length; i++) {
			//加上Double.MIN_VALUE是为了防止出现log(0)的情况
			entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE)
					/ Math.log(2);
		}
		entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);
		entropy /= sum;
		//由于上面加了Double.MIN_VALUE,所以算出来的熵可能会略大于1
		if (entropy > 1 && entropy - 1 < 0.00001)
			entropy = 1;
		return entropy;
	}

	//subset给写训练数据的一个子集(subset中存储的是每条数据的索引号),判断这些子集的决策变量值是否都相同
	public boolean infoPure(ArrayList<Integer> subset) {
		String value = data.get(subset.get(0))[decatt];
		for (int i = 1; i < subset.size(); i++) {
			String next = data.get(subset.get(i))[decatt];
			if (!value.equals(next))
				return false;
		}
		return true;
	}

	/**
	 * 计算节点的信息熵
	 * @param subset 节点上所包含的数据子集
	 * @param index 节点以第index个属性作为判断的依据
	 * @return 节点的信息熵
	 */
	public double calNodeEntropy(ArrayList<Integer> subset, int index) {
		int sum = subset.size();
		double entropy = 0.0;
		int[][] info = new int[attributevalue.get(index).size()][];
		for (int i = 0; i < info.length; i++)
			info[i] = new int[attributevalue.get(decatt).size()];
		int[] count = new int[attributevalue.get(index).size()];
		for (int i = 0; i < sum; i++) {
			int n = subset.get(i);
			String nodevalue = data.get(n)[index];
			int nodeind = attributevalue.get(index).indexOf(nodevalue);
			count[nodeind]++;
			String decvalue = data.get(n)[decatt];
			int decind = attributevalue.get(decatt).indexOf(decvalue);
			info[nodeind][decind]++;
		}
		for (int i = 0; i < info.length; i++) {
			entropy += getEntropy(info[i]) * count[i] / sum;
		}
		return entropy;
	}

	// 递归构建决策树
	public void buildDT(Element ele, ArrayList<Integer> subset,
			LinkedList<Integer> selatt) {
		
		//指定name和value的节点不包含数据子集时,递归可以终止。同时要删除该节点
		if (subset.size() == 0){
			ele.getParent().remove(ele);
			return;
		}
		
		//selatt.size() == 0说明树已经达到最大的深度,即所有判别属性都已经用完了。
		//这个时候递归还没有终止说明训练数据中存在判别属性值完全相同,决策属性值却不相同的情况,取决策属性值最多的情况为最终结果
		if(selatt.size() == 0){
			Map<String,Integer> map=new HashMap<String,Integer>();
			for(int i:subset){
				String key=data.get(i)[decatt];
				Integer v=map.get(key);
				if(v!=null)
					map.put(key, v+1);
				else
					map.put(key, 1);
			}
			
			String decision="should not appear";
			int maxCount=-1;
			Set<Entry<String,Integer>> set=map.entrySet();
			for(Entry<String,Integer> entry:set){
				if(entry.getValue()>maxCount){
					maxCount=entry.getValue();
					decision=entry.getKey();
				}
			}
			ele.setText(decision);
			return; 
		}
		
		//如果节点是纯的,那么就到达叶子节点了,给出决策,不需要继续递归了
		if (infoPure(subset)) {
			ele.setText(data.get(subset.get(0))[decatt]);		
			return;
		}
		
		//选择下一个用于判别的属性。应该选熵最小的,因为这样信息增益最大
		int minIndex = -1;
		double minEntropy = Double.MAX_VALUE;
		for (int i = 0; i < selatt.size(); i++) {
			if (i == decatt)
				continue;
			double entropy = calNodeEntropy(subset, selatt.get(i));
			if (entropy < minEntropy) {
				minIndex = selatt.get(i);
				minEntropy = entropy;
			}
		}
		String nodeName = attribute.get(minIndex);
		
		//每次递归时selatt都会少一个元素,即去除刚刚选择的判别属性
		selatt.remove(new Integer(minIndex));
		
		//刚刚选择的属性有多少种取值,该节点就有多少个分枝。遍历这些分枝,递归完善子树。
		ArrayList<String> attvalues = attributevalue.get(minIndex);
		for (String val : attvalues) {
			Element child=ele.addElement(nodeName).addAttribute("value", val);
			ArrayList<Integer> al = new ArrayList<Integer>();
			for (int i = 0; i < subset.size(); i++) {
				if (data.get(subset.get(i))[minIndex].equals(val)) {
					al.add(subset.get(i));
				}
			}
			//注意bBuildDT()里面selatt会被改变,所以每次传递这个参数的时候要进行深复制
			buildDT(child, al, new LinkedList<Integer>(selatt));
		}
	}

	//将Dom写入XML文件
	public void writeXML(String filename) {
		try {
			File file = new File(filename);
			if (!file.exists())
				file.createNewFile();
			FileWriter fw = new FileWriter(file);
			OutputFormat format = OutputFormat.createPrettyPrint();
			XMLWriter output = new XMLWriter(fw, format);
			output.write(xmldoc);
			output.close();
		} catch (IOException e) {
			System.out.println(e.getMessage());
		}
	}

	/*正面这两个函数用于正确读取中文文件*/
	
	String changeToGBK(String ss, String code) {
		String temp = null;
		try {
			temp = new String(ss.getBytes(), code);
		} catch (UnsupportedEncodingException e) {
			e.printStackTrace();
		}
		return temp;
	}

	public String initBookEncode(FileInputStream fileInputStream) {
		String encode = "gb2312";
		try {
			byte[] head = new byte[3];
			fileInputStream.read(head);
			if (head[0] == -17 && head[1] == -69 && head[2] == -65)
				encode = "UTF-8";
			else if (head[0] == -1 && head[1] == -2)
				encode = "UTF-16";
			else if (head[0] == -2 && head[1] == -1)
				encode = "Unicode";
		} catch (IOException e) {
			System.out.println(e.getMessage());

		}
		return encode;
	}
}


最终生成的文件如下:

<?xml version="1.0" encoding="UTF-8"?>

<root>
  <DecisionTree value="null">
    <outlook value="sunny">
      <humidity value="high">no</humidity>
      <humidity value="normal">yes</humidity>
    </outlook>
    <outlook value="overcast">yes</outlook>
    <outlook value="rainy">
      <windy value="TRUE">no</windy>
      <windy value="FALSE">yes</windy>
    </outlook>
  </DecisionTree>
</root>

用图形象地表示就是:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值