关联分析-Apriori算法Java实现 支持度+置信度(1)

apriori算法是最基本的发现频繁项集的算法,它的名字也体现了它的思想——先验,采用逐层搜索迭代的方法,挖掘任何可能的项集,k项集用于挖掘k+1项集。

先验性质

频繁项集的所有非空子集也一定是频繁的

该性质体现了项集挖掘中的反单调性,如果k项集不是频繁的,那么k+1项集一定也不是。基于这一点,算法的基本思想为:

step 1:连接

    为了搜索k项集,将k-1项集自连接产生候选k项集,称为候选集。

    为了有效的实现连接,首先对每一项进行排序。其次,若满足连接的条件,则进行连接。

    连接的条件,前k-2项相同,k-1项不同

step 2:剪枝

    k项集的每一个k-1项子集都存在与k-1项集,并且支持度满足最小支持度阀值。

伪代码:

C<k>:candidata itemset of size k 
L<k>:frequent itemset of size k 
L<1>=frequent items 
 for(k=1;L<k>!=null;k++) 
    C<k+1>=candidates generated from L<k> 
    for transaction t in dataset 
        increment the count of all candidates in C<k+1> that are contained in t     L<k+1>=candidates in C<k+1> with support>=min_support 
return
Java代码实现方式:

抽象了一个项集实体类,并实现是否可以合并的方法,这个方法最初是使用TreeSet.headSet来实现的,但是在测试时发现性能瓶颈都产生在这个方法上,并造成OOM,很是费解,待研究清楚后总结一下。

001 /**
002  *
003  */
004 package org.waitingfortime.datamining.association;
005  
006 import java.io.BufferedReader;
007 import java.io.File;
008 import java.io.FileNotFoundException;
009 import java.io.FileOutputStream;
010 import java.io.FileReader;
011 import java.io.IOException;
012 import java.io.PrintStream;
013 import java.util.ArrayList;
014 import java.util.HashMap;
015 import java.util.Iterator;
016 import java.util.List;
017 import java.util.Map;
018 import java.util.Set;
019 import java.util.TreeSet;
020  
021 /**
022  * @author mazhiyuan
023  *
024  */
025 public class Apriori {
026     private int minNum;// 最小支持数
027     private List<Set<Integer>> records;
028     private String output;
029     private List<List<ItemSet>> result = new ArrayList<List<ItemSet>>();
030  
031     public Apriori(double minDegree, String input, String output) {
032         this.output = output;
033         init(input);
034         if (records.size() == 0) {
035             System.err.println("不符合计算条件。退出!");
036             System.exit(1);
037         }
038         minNum = (int) (minDegree * records.size());
039     }
040  
041     private void init(String path) {
042         // TODO Auto-generated method stub
043         records = new ArrayList<Set<Integer>>();
044         try {
045             BufferedReader br = new BufferedReader(new FileReader(
046                     new File(path)));
047  
048             String line = null;
049             Set<Integer> record;
050             while ((line = br.readLine()) != null) {
051                 if (!"".equals(line.trim())) {
052                     record = new TreeSet<Integer>();
053                     String[] items = line.split(" ");
054                     for (String item : items) {
055                         record.add(Integer.valueOf(item));
056                     }
057                     records.add(record);
058                 }
059             }
060  
061             br.close();
062         catch (IOException e) {
063             System.err.println("读取事务文件失败。");
064         }
065     }
066  
067     private List<ItemSet> first() {
068         // TODO Auto-generated method stub
069         List<ItemSet> first = new ArrayList<ItemSet>();
070         Map<Integer, Integer> _first = new HashMap<Integer, Integer>();
071         for (Set<Integer> si : records)
072             for (Integer i : si) {
073                 if (_first.get(i) == null)
074                     _first.put(i, 1);
075                 else
076                     _first.put(i, _first.get(i) + 1);
077             }
078  
079         for (Integer i : _first.keySet())
080             if (_first.get(i) >= minNum)
081                 first.add(new ItemSet(i, _first.get(i)));
082  
083         return first;
084     }
085  
086     private void loop(List<ItemSet> items) {
087         // TODO Auto-generated method stub
088         List<ItemSet> copy = new ArrayList<ItemSet>(items);
089         List<ItemSet> res = new ArrayList<ItemSet>();
090         int size = items.size();
091  
092         // 连接
093         for (int i = 0; i < size; i++)
094             for (int j = i + 1; j < size; j++)
095                 if (copy.get(i).isMerge(copy.get(j))) {
096                     ItemSet is = new ItemSet(copy.get(i));
097                     is.merge(copy.get(j).item.last());
098                     res.add(is);
099                 }
100         // 剪枝
101         pruning(copy, res);
102  
103         if (res.size() != 0) {
104             result.add(res);
105             loop(res);
106         }
107     }
108  
109     private void pruning(List<ItemSet> pre, List<ItemSet> res) {
110         // TODO Auto-generated method stub
111         // step 1 k项集的子集属于k-1项集
112         Iterator<ItemSet> ir = res.iterator();
113         while (ir.hasNext()) {
114             // 获取所有k-1项子集
115             ItemSet now = ir.next();
116             List<List<Integer>> ss = subSet(now);
117             // 判断是否在pre集中
118             boolean flag = false;
119             for (List<Integer> li : ss) {
120                 if (flag)
121                     break;
122                 for (ItemSet pis : pre) {
123                     if (pis.item.containsAll(li)) {
124                         flag = false;
125                         break;
126                     }
127                     flag = true;
128                 }
129             }
130             if (flag) {
131                 ir.remove();
132                 continue;
133             }
134             // step 2 支持度
135             int i = 0;
136             for (Set<Integer> sr : records) {
137                 if (sr.containsAll(now.item))
138                     i++;
139  
140                 now.value = i;
141             }
142             if (now.value < minNum)
143                 ir.remove();
144         }
145     }
146  
147     private List<List<Integer>> subSet(ItemSet is) {
148         // TODO Auto-generated method stub
149         List<Integer> li = new ArrayList<Integer>(is.item);
150         List<List<Integer>> res = new ArrayList<List<Integer>>();
151         for (int i = 0, j = li.size(); i < j; i++) {
152             List<Integer> _li = new ArrayList<Integer>(li);
153             _li.remove(i);
154             res.add(_li);
155         }
156         return res;
157     }
158  
159     private void output() throws FileNotFoundException {
160         if (result.size() == 0) {
161             System.err.println("无结果集。退出!");
162             return;
163         }
164         FileOutputStream out = new FileOutputStream(output);
165         PrintStream ps = new PrintStream(out);
166         for (List<ItemSet> li : result) {
167             ps.println("=============频繁"+li.get(0).item.size()+"项集=============");
168             for (ItemSet is : li)
169                 ps.println(is.item + " : " + is.value);
170             ps.println("=====================================");
171         }
172     }
173  
174     /**
175      * @param args
176      * @throws FileNotFoundException
177      */
178     public static void main(String[] args) throws FileNotFoundException {
179         // TODO Auto-generated method stub
180         long begin = System.currentTimeMillis();
181         Apriori apriori = new Apriori(0.25,
182                 "/home/mazhiyuan/code/mushroom.dat",
183                 "/home/mazhiyuan/code/mout.data");
184         // apriori.first();//频繁1项集
185         apriori.loop(apriori.first());
186         apriori.output();
187         System.out.println((System.currentTimeMillis()) - begin);
188     }
189 }
190  
191 class ItemSet {
192     TreeSet<Integer> item;
193     int value;
194  
195     ItemSet(ItemSet is) {
196         this.item = new TreeSet<Integer>(is.item);
197     }
198  
199     ItemSet() {
200         item = new TreeSet<Integer>();
201     }
202  
203     ItemSet(int i, int v) {
204         this();
205         merge(i);
206         setValue(v);
207     }
208  
209     void setValue(int i) {
210         this.value = i;
211     }
212  
213     void merge(int i) {
214         item.add(i);
215     }
216  
217     boolean isMerge(ItemSet other) {
218         if (other == null || other.item.size() != item.size())
219             return false;
220         // 前k-1项相同,最后一项不同,满足连接条件
221         /*
222          * Iterator<Integer> i = item.headSet(item.last()).iterator();
223          * Iterator<Integer> o =
224          * other.item.headSet(other.item.last()).iterator(); while (i.hasNext() &&
225          * o.hasNext()) if (i.next() != o.next()) return false;
226          */
227         Iterator<Integer> i = item.iterator();
228         Iterator<Integer> o = other.item.iterator();
229         int n = item.size();
230         while (i.hasNext() && o.hasNext() && --n > 0)
231             if (i.next() != o.next())
232                 return false;
233  
234         return !(item.last() == other.item.last());
235     }
236 }
使用mushroom数据集,整个运行时间只有大概6s,性能还算满意。

这个代码只是计算了频繁项集,还没有计算关联规则和置信度,稍后补上。


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值