说明
参考文章-归纳决策树ID3(Java实现),完成代码编写。
在原代码的基础上补充了预测函数,实现利用模型对新数据进行分类预测。
作者对ID3决策树的介绍-ID3决策树
决策树采用xml文件保存,使用Dom4J类库,点击下载
让Dom4J支持按XPath选择节点,还得引入包jaxen.jar,点击下载
源代码汇总,点击下载
思路
代码
输入文件采用ARFF格式,使用的训练数据文件如下:
train.arff
@relation weather.symbolic
@attribute outlook {sunny,overcast,rainy}
@attribute temperature {hot,mild,cool}
@attribute humidity {high,normal}
@attribute windy {
TRUE,FALSE}
@attribute play {yes,no}
@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
ARFF(Attribute-Relation File Format):格式简单明了,分为两部分,第一部分交代属性及取值范围,第二部分则是数据部分(data)。
由于只是测试代码效果,测试集(predict.arff)也是上述数据,只是将类标相关的数据移除了。
ID3类
package ID3;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.Character.Subset;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;
import org.w3c.dom.NodeList;
public class ID3 {
// 同时保留训练集和测试集的数据在模型中,防止训练集和测试集的列顺序不同
private ArrayList<String> trainAttribute = new ArrayList<String>(); // 存储训练集属性的名称
private ArrayList<ArrayList<String>> train_attributeValue = new ArrayList<ArrayList<String>>(); // 存储训练集每个属性的取值
private ArrayList<String> predictAttribute = new ArrayList<String>(); // 存储测试集属性的名称
private ArrayList<ArrayList<String>> predict_attributeValue = new ArrayList<ArrayList<String>>(); // 存储测试集每个属性的取值
private ArrayList<String[]> train_data = new ArrayList<String[]>(); // 训练集数据 ,即arff文件中的data字符串
private ArrayList<String[]> predict_data = new ArrayList<String[]>(); // 测试集数据
private String[] preLable;
int decatt; // 决策变量在属性集中的索引(即类标所在列)
public static final String patternString =