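这里不再介绍 ID3 算法本身。程序把最终生成的决策树保存在 XML 文件中,XML 的构建与输出使用 Dom4J 库。注意:要让 Dom4J 支持按 XPath 选择节点(程序中调用了 selectNodes),还需要额外引入 jaxen.jar。程序要求输入文件为 ARFF 格式,并且所有属性都必须是标称(nominal)属性,即在 ARFF 头部以 {值1,值2,...} 的形式列出全部取值;其他类型的属性行不会被解析。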
003 | import java.io.BufferedReader; |
005 | import java.io.FileReader; |
006 | import java.io.FileWriter; |
007 | import java.io.IOException; |
008 | import java.util.ArrayList; |
009 | import java.util.Iterator; |
010 | import java.util.LinkedList; |
011 | import java.util.List; |
012 | import java.util.regex.Matcher; |
013 | import java.util.regex.Pattern; |
015 | import org.dom4j.Document; |
016 | import org.dom4j.DocumentHelper; |
017 | import org.dom4j.Element; |
018 | import org.dom4j.io.OutputFormat; |
019 | import org.dom4j.io.XMLWriter; |
// Training data model, filled by readARFF (the enclosing class header is not part of this listing).
// attribute: attribute names, in the order their "@attribute" lines appear in the ARFF header.
022 | private ArrayList<String> attribute = new ArrayList<String>(); |
// attributevalue: attributevalue.get(i) holds the nominal values declared for attribute i,
// in declaration order.
023 | private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); |
// data: one String[] per row of the @data section, split on ','.
// NOTE(review): the trailing ";;" is a stray empty member declaration - harmless, but can be dropped.
024 | private ArrayList<String[]> data = new ArrayList<String[]>();; |
// Matches a nominal attribute declaration such as "@attribute outlook {sunny, overcast, rainy}":
// group 1 = the attribute name (untrimmed), group 2 = the comma-separated values inside the braces.
// Only a lowercase "@attribute" followed by a {...} value list matches, so numeric/string attributes
// and "@ATTRIBUTE" spellings are silently ignored.
026 | public static final String patternString = "@attribute(.*)[{](.*?)[}]" ; |
// Constructor body (its header is missing from this listing): creates the output document with a
// <root> element holding a single <DecisionTree value="null"> node; main's initial buildDT call
// starts growing the tree from that node.
032 | xmldoc = DocumentHelper.createDocument(); |
033 | root = xmldoc.addElement( "root" ); |
034 | root.addElement( "DecisionTree" ).addAttribute( "value" , "null" ); |
// Demo driver: learns an ID3 tree from a hard-coded ARFF file and writes it out as XML.
// NOTE(review): several lines of this method are missing from this listing (presumably the setDec
// call that selects the decision attribute, and the two loop bodies) - TODO confirm.
// NOTE(review): input/output paths are hard-coded; consider taking them from args.
037 | public static void main(String[] args) { |
038 | ID3 inst = new ID3(); |
039 | inst.readARFF( new File( "/home/orisun/test/weather.nominal.arff" )); |
// ll: indices of the candidate attributes buildDT may split on (loop body not shown; presumably
// it skips the decision attribute - TODO confirm).
041 | LinkedList<Integer> ll= new LinkedList<Integer>(); |
042 | for ( int i= 0 ;i<inst.attribute.size();i++){ |
// al: row indices of the training subset; presumably every row of inst.data (loop body not shown).
046 | ArrayList<Integer> al= new ArrayList<Integer>(); |
047 | for ( int i= 0 ;i<inst.data.size();i++){ |
// Grow the tree under the <DecisionTree value="null"> node created in the constructor, then save it.
050 | inst.buildDT( "DecisionTree" , "null" , al, ll); |
051 | inst.writeXML( "/home/orisun/test/dt.xml" ); |
// Parses an ARFF file. Each nominal "@attribute" line appends its name to `attribute` and its trimmed
// values to `attributevalue`. Lines after "@data" are split on ',' into rows (presumably appended to
// `data`; any blank/comment-line filtering is on lines missing from this listing).
// NOTE(review): several lines (try header, `line` declaration, loop bodies, braces) are missing here.
// NOTE(review): FileReader decodes with the platform default charset, and no close of the reader is
// visible - prefer try-with-resources with an explicit charset.
// NOTE(review): an IOException is only printed, so callers carry on with partially filled data.
056 | public void readARFF(File file) { |
058 | FileReader fr = new FileReader(file); |
059 | BufferedReader br = new BufferedReader(fr); |
061 | Pattern pattern = Pattern.compile(patternString); |
062 | while ((line = br.readLine()) != null ) { |
063 | Matcher matcher = pattern.matcher(line); |
064 | if (matcher.find()) { |
// Header line: group 1 is the attribute name, group 2 the comma-separated nominal values.
065 | attribute.add(matcher.group( 1 ).trim()); |
066 | String[] values = matcher.group( 2 ).split( "," ); |
067 | ArrayList<String> al = new ArrayList<String>(values.length); |
068 | for (String value : values) { |
069 | al.add(value.trim()); |
071 | attributevalue.add(al); |
// "@data" (case-sensitive) starts the data section; the inner loop consumes the rest of the file.
072 | } else if (line.startsWith( "@data" )) { |
073 | while ((line = br.readLine()) != null ) { |
// NOTE(review): unlike the header values, row fields are not trimmed. A row written as
// "sunny, hot, ..." then makes the indexOf lookups in calNodeEntropy return -1 and index out of bounds.
076 | String[] row = line.split( "," ); |
084 | } catch (IOException e1) { |
085 | e1.printStackTrace(); |
// Selects the decision (class) attribute by its index in `attribute`.
// An out-of-range index prints an error message ("decision variable specified incorrectly") to
// stderr; the rest of the method (presumably exiting and assigning decatt) is missing from this
// listing - TODO confirm.
090 | public void setDec( int n) { |
091 | if (n < 0 || n >= attribute.size()) { |
092 | System.err.println( "决策变量指定错误。" ); |
// Selects the decision attribute by name. attribute.indexOf returns -1 for an unknown name, which the
// index overload's range check rejects; presumably n is forwarded to setDec(int) on a line missing
// from this listing (TODO confirm).
097 | public void setDec(String name) { |
098 | int n = attribute.indexOf(name); |
// Shannon entropy (in bits) of a histogram of class counts, via the identity
//   H = (N*log2(N) - sum_i c_i*log2(c_i)) / N,   where N = sum_i c_i.
// The visible lines build the numerator. The declaration/accumulation of `sum` and the final division
// by it are on lines missing from this listing - TODO confirm.
// Adding Double.MIN_VALUE keeps log() finite for zero counts, so 0*log(0) contributes 0 instead of
// NaN. For counts >= 1 the addition is below double precision and changes nothing.
// NOTE(review): if every count is 0 (N == 0) and the result is divided by N, it is NaN.
// calNodeEntropy can pass such an all-zero row for an attribute value absent from its subset.
103 | public double getEntropy( int [] arr) { |
104 | double entropy = 0.0 ; |
106 | for ( int i = 0 ; i < arr.length; i++) { |
107 | entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log( 2 ); |
110 | entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log( 2 ); |
// Entropy (in bits) of a histogram of class counts, with the total N supplied by the caller:
//   H = (N*log2(N) - sum_i c_i*log2(c_i)) / N.
// Presumably `sum` must equal the sum of arr; this is not checked here (TODO confirm).
// Double.MIN_VALUE keeps log() finite for zero counts so they contribute 0 rather than NaN.
// The remaining lines (loop close, final division/return) are missing from this listing.
// No caller of this overload appears in this listing; calNodeEntropy uses the one-argument form.
116 | public double getEntropy( int [] arr, int sum) { |
117 | double entropy = 0.0 ; |
118 | for ( int i = 0 ; i < arr.length; i++) { |
119 | entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log( 2 ); |
121 | entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log( 2 ); |
// Checks whether every row in subset has the same decision-attribute value (decatt), i.e. whether the
// node is pure and can become a leaf. The return statements are missing from this listing; presumably
// it returns false on the first mismatch and true otherwise.
// NOTE(review): subset.get(0) throws IndexOutOfBoundsException for an empty subset. buildDT can
// recurse with an empty subset when an attribute value has no matching rows in the current branch.
126 | public boolean infoPure(ArrayList<Integer> subset) { |
127 | String value = data.get(subset.get( 0 ))[decatt]; |
128 | for ( int i = 1 ; i < subset.size(); i++) { |
129 | String next=data.get(subset.get(i))[decatt]; |
// equals() compares string contents (== would compare references).
131 | if (!value.equals(next)) |
// Conditional entropy H(decision | attribute `index`) over the rows listed in subset:
//   info[v][c] counts rows whose attribute value is v and decision value is c,
//   count[v]   counts rows whose attribute value is v,
// and each value's class entropy is weighted by count[v]/sum. The return of `entropy` is not shown.
// buildDT picks the attribute minimizing this. That is equivalent to maximizing information gain,
// because the node's own entropy is the same for every candidate.
// NOTE(review): count[] is never incremented in the visible lines. The increment (presumably
// count[nodeind]++) must be on a line missing from this listing; otherwise the result is always 0.
// NOTE(review): indexOf returns -1 for a value not declared in the ARFF header, which makes the
// array accesses below throw ArrayIndexOutOfBoundsException.
// NOTE(review): a declared value absent from subset yields an all-zero info[v] row with
// count[v] == 0. If getEntropy returns NaN for it, NaN * 0 is still NaN and poisons the sum, so this
// attribute can never win buildDT's "entropy < minEntropy" comparison.
138 | public double calNodeEntropy(ArrayList<Integer> subset, int index) { |
139 | int sum = subset.size(); |
140 | double entropy = 0.0 ; |
141 | int [][] info = new int [attributevalue.get(index).size()][]; |
142 | for ( int i = 0 ; i < info.length; i++) |
143 | info[i] = new int [attributevalue.get(decatt).size()]; |
144 | int [] count = new int [attributevalue.get(index).size()]; |
145 | for ( int i = 0 ; i < sum; i++) { |
146 | int n = subset.get(i); |
147 | String nodevalue = data.get(n)[index]; |
148 | int nodeind = attributevalue.get(index).indexOf(nodevalue); |
150 | String decvalue = data.get(n)[decatt]; |
151 | int decind = attributevalue.get(decatt).indexOf(decvalue); |
152 | info[nodeind][decind]++; |
154 | for ( int i = 0 ; i < info.length; i++) { |
155 | entropy += getEntropy(info[i]) * count[i] / sum; |
// Recursively grows the decision tree in the XML document.
//   name/value - element name and "value" attribute identifying the node to grow
//   subset     - indices into data of the rows that reach this node
//   selatt     - indices of the attributes still available for splitting
// If the rows are pure, the node's text becomes their decision value (a leaf). Otherwise it:
//   1. picks the candidate attribute with the lowest conditional entropy,
//   2. adds one child element per value of that attribute,
//   3. recurses on each child with the matching rows.
// NOTE(review): several lines (declarations of ele and minIndex, loop bodies, the early return,
// braces) are missing from this listing.
161 | public void buildDT(String name, String value, ArrayList<Integer> subset, |
162 | LinkedList<Integer> selatt) { |
// dom4j's selectNodes returns a raw List here, hence the locally scoped unchecked suppression.
// NOTE(review): "//"+name searches the WHOLE document (XPath support needs jaxen), matching only by
// element name + "value". Because selatt below is shared, each attribute name is used by only one
// split in the whole tree, which is what currently keeps this lookup unambiguous. Passing the parent
// Element into the recursion would remove the lookup entirely.
164 | @SuppressWarnings ( "unchecked" ) |
165 | List<Element> list = root.selectNodes( "//" +name); |
166 | Iterator<Element> iter=list.iterator(); |
167 | while (iter.hasNext()){ |
169 | if (ele.attributeValue( "value" ).equals(value)) |
// Pure node: label the leaf with the shared decision value.
172 | if (infoPure(subset)) { |
173 | ele.setText(data.get(subset.get( 0 ))[decatt]); |
177 | double minEntropy = Double.MAX_VALUE; |
178 | for ( int i = 0 ; i < selatt.size(); i++) { |
181 | double entropy = calNodeEntropy(subset, selatt.get(i)); |
182 | if (entropy < minEntropy) { |
183 | minIndex = selatt.get(i); |
184 | minEntropy = entropy; |
// NOTE(review): minIndex keeps its initial value (declared on a missing line) in two cases:
// selatt is empty, or no candidate compares below MAX_VALUE (e.g. every entropy is NaN).
// If that initial value is -1, attribute.get(minIndex) throws.
187 | String nodeName = attribute.get(minIndex); |
// remove(Object) is intended: the boxed Integer removes by value, not by position.
// new Integer(int) is deprecated since Java 9; Integer.valueOf(minIndex) is the drop-in replacement.
// NOTE(review): selatt is one shared list, mutated here and passed to every child call. Attributes
// consumed in one child's subtree therefore become unavailable to its later siblings too. Standard
// ID3 gives each branch its own copy (e.g. new LinkedList<Integer>(selatt)). Doing that also requires
// replacing the "//"+name lookup above, which would otherwise become ambiguous.
188 | selatt.remove( new Integer(minIndex)); |
189 | ArrayList<String> attvalues = attributevalue.get(minIndex); |
190 | for (String val : attvalues) { |
191 | ele.addElement(nodeName).addAttribute( "value" , val); |
// al: rows of this node's subset whose split-attribute value equals val.
// NOTE(review): al is empty when no row in this branch has value val, and the recursive infoPure(al)
// call then throws. The usual ID3 fix is to label such a child with this node's majority class.
192 | ArrayList<Integer> al = new ArrayList<Integer>(); |
193 | for ( int i = 0 ; i < subset.size(); i++) { |
194 | if (data.get(subset.get(i))[minIndex].equals(val)) { |
195 | al.add(subset.get(i)); |
198 | buildDT(nodeName, val, al, selatt); |
// Writes xmldoc to the given file, pretty-printed (indented) by dom4j.
// NOTE(review): createNewFile() is redundant; FileWriter creates the file if it does not exist.
// NOTE(review): FileWriter encodes with the platform default charset, while dom4j's pretty-print
// format declares encoding="UTF-8". On a non-UTF-8 platform, non-ASCII text would be mislabelled.
// Writing through a FileOutputStream with XMLWriter(OutputStream, OutputFormat) lets dom4j encode.
// NOTE(review): no close/flush of the writer is visible in this listing; close it in a finally block
// (or manage the underlying stream with try-with-resources) so output is flushed and the handle freed.
// NOTE(review): failures are reported only via e.getMessage() on stdout, dropping the stack trace.
203 | public void writeXML(String filename) { |
205 | File file = new File(filename); |
207 | file.createNewFile(); |
208 | FileWriter fw = new FileWriter(file); |
209 | OutputFormat format = OutputFormat.createPrettyPrint(); |
210 | XMLWriter output = new XMLWriter(fw, format); |
211 | output.write(xmldoc); |
213 | } catch (IOException e) { |
214 | System.out.println(e.getMessage()); |
以 weather.nominal.arff 为输入运行程序后,生成的 dt.xml 文件内容如下:
<?xml version="1.0" encoding="UTF-8"?>

<root>
  <DecisionTree value="null">
    <outlook value="sunny">
      <humidity value="high">no</humidity>
      <humidity value="normal">yes</humidity>
    </outlook>
    <outlook value="overcast">yes</outlook>
    <outlook value="rainy">
      <windy value="TRUE">no</windy>
      <windy value="FALSE">yes</windy>
    </outlook>
  </DecisionTree>
</root>