apriori算法是最基本的发现频繁项集的算法,它的名字也体现了它的思想——先验,采用逐层搜索迭代的方法,挖掘任何可能的项集,k项集用于挖掘k+1项集。
先验性质
频繁项集的所有非空子集也一定是频繁的
该性质体现了项集挖掘中的反单调性,如果k项集不是频繁的,那么k+1项集一定也不是。基于这一点,算法的基本思想为:
step 1:连接
为了搜索k项集,将k-1项集自连接产生候选的k项集,称为候选集。
为了有效的实现连接,首先对每一项进行排序。其次,若满足连接的条件,则进行连接。
连接的条件,前k-2项相同,k-1项不同
step 2:剪枝
k项集的每一个k-1项子集都存在与k-1项集,并且支持度满足最小支持度阀值。
伪代码:
C<k>:candidata itemset of size kJava代码实现方式:
L<k>:frequent itemset of size k
L<1>=frequent items
for(k=1;L<k>!=null;k++)
C<k+1>=candidates generated from L<k>
for transaction t in dataset
increment the count of all candidates in C<k+1> that are contained in t L<k+1>=candidates in C<k+1> with support>=min_support
return
抽象了一个项集实体类,并实现是否可以合并的方法,这个方法最初是使用TreeSet.headSet来实现的,但是在测试时发现性能瓶颈都产生在这个方法上,并造成OOM,很是费解,待研究清楚后总结一下。
001 | /** |
002 | * |
003 | */ |
004 | package org.waitingfortime.datamining.association; |
005 |
006 | import java.io.BufferedReader; |
007 | import java.io.File; |
008 | import java.io.FileNotFoundException; |
009 | import java.io.FileOutputStream; |
010 | import java.io.FileReader; |
011 | import java.io.IOException; |
012 | import java.io.PrintStream; |
013 | import java.util.ArrayList; |
014 | import java.util.HashMap; |
015 | import java.util.Iterator; |
016 | import java.util.List; |
017 | import java.util.Map; |
018 | import java.util.Set; |
019 | import java.util.TreeSet; |
020 |
021 | /** |
022 | * @author mazhiyuan |
023 | * |
024 | */ |
025 | public class Apriori { |
026 | private int minNum; // 最小支持数 |
027 | private List<Set<Integer>> records; |
028 | private String output; |
029 | private List<List<ItemSet>> result = new ArrayList<List<ItemSet>>(); |
030 |
031 | public Apriori( double minDegree, String input, String output) { |
032 | this .output = output; |
033 | init(input); |
034 | if (records.size() == 0 ) { |
035 | System.err.println( "不符合计算条件。退出!" ); |
036 | System.exit( 1 ); |
037 | } |
038 | minNum = ( int ) (minDegree * records.size()); |
039 | } |
040 |
041 | private void init(String path) { |
042 | // TODO Auto-generated method stub |
043 | records = new ArrayList<Set<Integer>>(); |
044 | try { |
045 | BufferedReader br = new BufferedReader( new FileReader( |
046 | new File(path))); |
047 |
048 | String line = null ; |
049 | Set<Integer> record; |
050 | while ((line = br.readLine()) != null ) { |
051 | if (! "" .equals(line.trim())) { |
052 | record = new TreeSet<Integer>(); |
053 | String[] items = line.split( " " ); |
054 | for (String item : items) { |
055 | record.add(Integer.valueOf(item)); |
056 | } |
057 | records.add(record); |
058 | } |
059 | } |
060 |
061 | br.close(); |
062 | } catch (IOException e) { |
063 | System.err.println( "读取事务文件失败。" ); |
064 | } |
065 | } |
066 |
067 | private List<ItemSet> first() { |
068 | // TODO Auto-generated method stub |
069 | List<ItemSet> first = new ArrayList<ItemSet>(); |
070 | Map<Integer, Integer> _first = new HashMap<Integer, Integer>(); |
071 | for (Set<Integer> si : records) |
072 | for (Integer i : si) { |
073 | if (_first.get(i) == null ) |
074 | _first.put(i, 1 ); |
075 | else |
076 | _first.put(i, _first.get(i) + 1 ); |
077 | } |
078 |
079 | for (Integer i : _first.keySet()) |
080 | if (_first.get(i) >= minNum) |
081 | first.add( new ItemSet(i, _first.get(i))); |
082 |
083 | return first; |
084 | } |
085 |
086 | private void loop(List<ItemSet> items) { |
087 | // TODO Auto-generated method stub |
088 | List<ItemSet> copy = new ArrayList<ItemSet>(items); |
089 | List<ItemSet> res = new ArrayList<ItemSet>(); |
090 | int size = items.size(); |
091 |
092 | // 连接 |
093 | for ( int i = 0 ; i < size; i++) |
094 | for ( int j = i + 1 ; j < size; j++) |
095 | if (copy.get(i).isMerge(copy.get(j))) { |
096 | ItemSet is = new ItemSet(copy.get(i)); |
097 | is.merge(copy.get(j).item.last()); |
098 | res.add(is); |
099 | } |
100 | // 剪枝 |
101 | pruning(copy, res); |
102 |
103 | if (res.size() != 0 ) { |
104 | result.add(res); |
105 | loop(res); |
106 | } |
107 | } |
108 |
109 | private void pruning(List<ItemSet> pre, List<ItemSet> res) { |
110 | // TODO Auto-generated method stub |
111 | // step 1 k项集的子集属于k-1项集 |
112 | Iterator<ItemSet> ir = res.iterator(); |
113 | while (ir.hasNext()) { |
114 | // 获取所有k-1项子集 |
115 | ItemSet now = ir.next(); |
116 | List<List<Integer>> ss = subSet(now); |
117 | // 判断是否在pre集中 |
118 | boolean flag = false ; |
119 | for (List<Integer> li : ss) { |
120 | if (flag) |
121 | break ; |
122 | for (ItemSet pis : pre) { |
123 | if (pis.item.containsAll(li)) { |
124 | flag = false ; |
125 | break ; |
126 | } |
127 | flag = true ; |
128 | } |
129 | } |
130 | if (flag) { |
131 | ir.remove(); |
132 | continue ; |
133 | } |
134 | // step 2 支持度 |
135 | int i = 0 ; |
136 | for (Set<Integer> sr : records) { |
137 | if (sr.containsAll(now.item)) |
138 | i++; |
139 |
140 | now.value = i; |
141 | } |
142 | if (now.value < minNum) |
143 | ir.remove(); |
144 | } |
145 | } |
146 |
147 | private List<List<Integer>> subSet(ItemSet is) { |
148 | // TODO Auto-generated method stub |
149 | List<Integer> li = new ArrayList<Integer>(is.item); |
150 | List<List<Integer>> res = new ArrayList<List<Integer>>(); |
151 | for ( int i = 0 , j = li.size(); i < j; i++) { |
152 | List<Integer> _li = new ArrayList<Integer>(li); |
153 | _li.remove(i); |
154 | res.add(_li); |
155 | } |
156 | return res; |
157 | } |
158 |
159 | private void output() throws FileNotFoundException { |
160 | if (result.size() == 0 ) { |
161 | System.err.println( "无结果集。退出!" ); |
162 | return ; |
163 | } |
164 | FileOutputStream out = new FileOutputStream(output); |
165 | PrintStream ps = new PrintStream(out); |
166 | for (List<ItemSet> li : result) { |
167 | ps.println( "=============频繁" +li.get( 0 ).item.size()+ "项集=============" ); |
168 | for (ItemSet is : li) |
169 | ps.println(is.item + " : " + is.value); |
170 | ps.println( "=====================================" ); |
171 | } |
172 | } |
173 |
174 | /** |
175 | * @param args |
176 | * @throws FileNotFoundException |
177 | */ |
178 | public static void main(String[] args) throws FileNotFoundException { |
179 | // TODO Auto-generated method stub |
180 | long begin = System.currentTimeMillis(); |
181 | Apriori apriori = new Apriori( 0.25 , |
182 | "/home/mazhiyuan/code/mushroom.dat" , |
183 | "/home/mazhiyuan/code/mout.data" ); |
184 | // apriori.first();//频繁1项集 |
185 | apriori.loop(apriori.first()); |
186 | apriori.output(); |
187 | System.out.println((System.currentTimeMillis()) - begin); |
188 | } |
189 | } |
190 |
191 | class ItemSet { |
192 | TreeSet<Integer> item; |
193 | int value; |
194 |
195 | ItemSet(ItemSet is) { |
196 | this .item = new TreeSet<Integer>(is.item); |
197 | } |
198 |
199 | ItemSet() { |
200 | item = new TreeSet<Integer>(); |
201 | } |
202 |
203 | ItemSet( int i, int v) { |
204 | this (); |
205 | merge(i); |
206 | setValue(v); |
207 | } |
208 |
209 | void setValue( int i) { |
210 | this .value = i; |
211 | } |
212 |
213 | void merge( int i) { |
214 | item.add(i); |
215 | } |
216 |
217 | boolean isMerge(ItemSet other) { |
218 | if (other == null || other.item.size() != item.size()) |
219 | return false ; |
220 | // 前k-1项相同,最后一项不同,满足连接条件 |
221 | /* |
222 | * Iterator<Integer> i = item.headSet(item.last()).iterator(); |
223 | * Iterator<Integer> o = |
224 | * other.item.headSet(other.item.last()).iterator(); while (i.hasNext() && |
225 | * o.hasNext()) if (i.next() != o.next()) return false; |
226 | */ |
227 | Iterator<Integer> i = item.iterator(); |
228 | Iterator<Integer> o = other.item.iterator(); |
229 | int n = item.size(); |
230 | while (i.hasNext() && o.hasNext() && --n > 0 ) |
231 | if (i.next() != o.next()) |
232 | return false ; |
233 |
234 | return !(item.last() == other.item.last()); |
235 | } |
236 | } |
这个代码只是计算了频繁项集,还没有计算关联规则和置信度,稍后补上。