归纳决策树ID3(Java实现)

ID3算法本身这里不再赘述。最终生成的决策树保存在XML文件中，XML操作使用了Dom4J；注意，若要让Dom4J支持按XPath选择节点，还需引入jaxen.jar。程序要求输入文件为ARFF格式，且所有属性都必须是标称（nominal）变量。

001package dt;
002
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;
020
021public class ID3 {
022    private ArrayList<String> attribute = new ArrayList<String>(); // 存储属性的名称
023    private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // 存储每个属性的取值
024    private ArrayList<String[]> data = new ArrayList<String[]>();; // 原始数据
025    int decatt; // 决策变量在属性集中的索引
026    public static final String patternString = "@attribute(.*)[{](.*?)[}]";
027
028    Document xmldoc;
029    Element root;
030
031    public ID3() {
032        xmldoc = DocumentHelper.createDocument();
033        root = xmldoc.addElement("root");
034        root.addElement("DecisionTree").addAttribute("value", "null");
035    }
036
037    public static void main(String[] args) {
038        ID3 inst = new ID3();
039        inst.readARFF(new File("/home/orisun/test/weather.nominal.arff"));
040        inst.setDec("play");
041        LinkedList<Integer> ll=new LinkedList<Integer>();
042        for(int i=0;i<inst.attribute.size();i++){
043            if(i!=inst.decatt)
044                ll.add(i);
045        }
046        ArrayList<Integer> al=new ArrayList<Integer>();
047        for(int i=0;i<inst.data.size();i++){
048            al.add(i);
049        }
050        inst.buildDT("DecisionTree", "null", al, ll);
051        inst.writeXML("/home/orisun/test/dt.xml");
052        return;
053    }
054
055    //读取arff文件,给attribute、attributevalue、data赋值
056    public void readARFF(File file) {
057        try {
058            FileReader fr = new FileReader(file);
059            BufferedReader br = new BufferedReader(fr);
060            String line;
061            Pattern pattern = Pattern.compile(patternString);
062            while ((line = br.readLine()) != null) {
063                Matcher matcher = pattern.matcher(line);
064                if (matcher.find()) {
065                    attribute.add(matcher.group(1).trim());
066                    String[] values = matcher.group(2).split(",");
067                    ArrayList<String> al = new ArrayList<String>(values.length);
068                    for (String value : values) {
069                        al.add(value.trim());
070                    }
071                    attributevalue.add(al);
072                } else if (line.startsWith("@data")) {
073                    while ((line = br.readLine()) != null) {
074                        if(line=="")
075                            continue;
076                        String[] row = line.split(",");
077                        data.add(row);
078                    }
079                } else {
080                    continue;
081                }
082            }
083            br.close();
084        } catch (IOException e1) {
085            e1.printStackTrace();
086        }
087    }
088
089    //设置决策变量
090    public void setDec(int n) {
091        if (n < 0 || n >= attribute.size()) {
092            System.err.println("决策变量指定错误。");
093            System.exit(2);
094        }
095        decatt = n;
096    }
097    public void setDec(String name) {
098        int n = attribute.indexOf(name);
099        setDec(n);
100    }
101
102    //给一个样本(数组中是各种情况的计数),计算它的熵
103    public double getEntropy(int[] arr) {
104        double entropy = 0.0;
105        int sum = 0;
106        for (int i = 0; i < arr.length; i++) {
107            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);
108            sum += arr[i];
109        }
110        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);
111        entropy /= sum;
112        return entropy;
113    }
114
115    //给一个样本数组及样本的算术和,计算它的熵
116    public double getEntropy(int[] arr, int sum) {
117        double entropy = 0.0;
118        for (int i = 0; i < arr.length; i++) {
119            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);
120        }
121        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);
122        entropy /= sum;
123        return entropy;
124    }
125
126    public boolean infoPure(ArrayList<Integer> subset) {
127        String value = data.get(subset.get(0))[decatt];
128        for (int i = 1; i < subset.size(); i++) {
129            String next=data.get(subset.get(i))[decatt];
130            //equals表示对象内容相同,==表示两个对象指向的是同一片内存
131            if (!value.equals(next))
132                return false;
133        }
134        return true;
135    }
136
137    // 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵
138    public double calNodeEntropy(ArrayList<Integer> subset, int index) {
139        int sum = subset.size();
140        double entropy = 0.0;
141        int[][] info = new int[attributevalue.get(index).size()][];
142        for (int i = 0; i < info.length; i++)
143            info[i] = new int[attributevalue.get(decatt).size()];
144        int[] count = new int[attributevalue.get(index).size()];
145        for (int i = 0; i < sum; i++) {
146            int n = subset.get(i);
147            String nodevalue = data.get(n)[index];
148            int nodeind = attributevalue.get(index).indexOf(nodevalue);
149            count[nodeind]++;
150            String decvalue = data.get(n)[decatt];
151            int decind = attributevalue.get(decatt).indexOf(decvalue);
152            info[nodeind][decind]++;
153        }
154        for (int i = 0; i < info.length; i++) {
155            entropy += getEntropy(info[i]) * count[i] / sum;
156        }
157        return entropy;
158    }
159
160    // 构建决策树
161    public void buildDT(String name, String value, ArrayList<Integer> subset,
162            LinkedList<Integer> selatt) {
163        Element ele = null;
164        @SuppressWarnings("unchecked")
165        List<Element> list = root.selectNodes("//"+name);
166        Iterator<Element> iter=list.iterator();
167        while(iter.hasNext()){
168            ele=iter.next();
169            if(ele.attributeValue("value").equals(value))
170                break;
171        }
172        if (infoPure(subset)) {
173            ele.setText(data.get(subset.get(0))[decatt]);
174            return;
175        }
176        int minIndex = -1;
177        double minEntropy = Double.MAX_VALUE;
178        for (int i = 0; i < selatt.size(); i++) {
179            if (i == decatt)
180                continue;
181            double entropy = calNodeEntropy(subset, selatt.get(i));
182            if (entropy < minEntropy) {
183                minIndex = selatt.get(i);
184                minEntropy = entropy;
185            }
186        }
187        String nodeName = attribute.get(minIndex);
188        selatt.remove(new Integer(minIndex));
189        ArrayList<String> attvalues = attributevalue.get(minIndex);
190        for (String val : attvalues) {
191            ele.addElement(nodeName).addAttribute("value", val);
192            ArrayList<Integer> al = new ArrayList<Integer>();
193            for (int i = 0; i < subset.size(); i++) {
194                if (data.get(subset.get(i))[minIndex].equals(val)) {
195                    al.add(subset.get(i));
196                }
197            }
198            buildDT(nodeName, val, al, selatt);
199        }
200    }
201
202    // 把xml写入文件
203    public void writeXML(String filename) {
204        try {
205            File file = new File(filename);
206            if (!file.exists())
207                file.createNewFile();
208            FileWriter fw = new FileWriter(file);
209            OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式
210            XMLWriter output = new XMLWriter(fw, format);
211            output.write(xmldoc);
212            output.close();
213        } catch (IOException e) {
214            System.out.println(e.getMessage());
215        }
216    }
217}

最终生成的文件如下:

view source print ?
<?xml version="1.0" encoding="UTF-8"?>
<root>
  <DecisionTree value="null">
    <outlook value="sunny">
      <humidity value="high">no</humidity>
      <humidity value="normal">yes</humidity>
    </outlook>
    <outlook value="overcast">yes</outlook>
    <outlook value="rainy">
      <windy value="TRUE">no</windy>
      <windy value="FALSE">yes</windy>
    </outlook>
  </DecisionTree>
</root>
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值