这是一份用JAVA实现的Apriori算法,由于是完成的课程作业所以没有考虑代码的优化,算法的背景就不介绍了,核心步骤在于剪枝和判断剪枝后的候选项集的所有子集是否满足要求,在获取指定长度子集时有一些技巧,具体请看代码。其中项集用HashMap<Set<String>,integer>来表示,关键字用Set集合可以自动排序,值用于记录项集在原始事物数据中出现的次数。原始数据用文件方式读取,注意文件内容每一行为一个原始事物项,不需要输入事物的编号。参考数据集为数据挖掘教材(韩家炜)P163。
package datamining;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class Apriori {
//剪枝函数
public ArrayList<Set<String>> apriori_gen(HashMap<Set<String>, Integer> L_last, int last_index){
ArrayList<Set<String>> result = new ArrayList<Set<String>>(); //存储剪枝后的结果
ArrayList<ArrayList<String>> item_set = null;
item_set = get_item_set(L_last); //获取上一个频繁项的所有项集,并转为字符串List
for(int i = 0; i < item_set.size() - 1; i++) {
ArrayList<String> str = item_set.get(i);
for(int j = i + 1; j < item_set.size(); j++) {
Set<String> new_item = new HashSet<String>(); //存储新的候选项集
ArrayList<String> str2 = item_set.get(j);
int length = str.size();
for(int k = 0; k < length - 1; k++) { //进行join操作
if(!str.get(k).equals(str2.get(k)))
break;
else
new_item.add(str.get(k));
}
new_item.add(str.get(str.size()-1));
new_item.add(str2.get(str2.size()-1));
if(new_item.size() == length + 1 && has_infrequent_subset(new_item, item_set, last_index)) //判断新的候选项集是否满足所有K-1项子集要求
result.add(new_item); //满足则加入结果集
}
}
return result;
}
//判断新的item的所有K-1项子集是否在上一个频繁项中都出现
public boolean has_infrequent_subset(Set<String> candidate, ArrayList<ArrayList<String>> last_item_set, int last_index) {
boolean flag = true;
ArrayList<ArrayList<String>> sub_set = get_subset(candidate, last_index);
// for(int j = 0; j < sub_set.size(); j++) {
// System.out.println(sub_set.get(j));
// }
for(int i = 0; i < sub_set.size(); i++) {
ArrayList<String> item = sub_set.get(i);
int j = 0;
for(j = 0; j < last_item_set.size(); j++) {
if(last_item_set.get(j).equals(item))
break;
}
if( j == last_item_set.size()) flag = false;
}
return flag;
}
//获取候选项集的K-1项所有子集
public ArrayList<ArrayList<String>> get_subset(Set<String> candidate, int index){
ArrayList<ArrayList<String>> sub_set = new ArrayList<ArrayList<String>>();
ArrayList<String> item_set = new ArrayList<String>();
Iterator iter = candidate.iterator();
while(iter.hasNext())
item_set.add((String)iter.next());
if(index == 1) { //当index等于1时单独考虑
for(int k = 0; k < item_set.size(); k++) {
ArrayList<String> buffer = new ArrayList<String>();
buffer.add(item_set.get(k));
sub_set.add(buffer);
}
}else {
for(int i = 0; i < item_set.size() - index + 1; i++) {
for(int j = i + 1; j < item_set.size(); j++) {
ArrayList<String> buffer = new ArrayList<String>();
buffer.add(item_set.get(i));
for(int k = 0; k < index - 1; k++) { //关键步骤,循环index-1次
if((k + j) < item_set.size())
buffer.add(item_set.get(k+j));
}
if(buffer.size() == index)
sub_set.add(buffer);
}
}
}
return sub_set;
}
//获取上一个频繁项的所有项集并转为List方便处理
public ArrayList<ArrayList<String>> get_item_set(HashMap<Set<String>, Integer> L_last){
ArrayList<ArrayList<String>> result = new ArrayList<ArrayList<String>>();
Iterator iter = L_last.entrySet().iterator();
while (iter.hasNext()) {
Map.Entry entry = (Map.Entry) iter.next();
Set<String> set = (Set<String>)entry.getKey();
Iterator iter2 = set.iterator();
ArrayList<String> item = new ArrayList<String>();
while(iter2.hasNext()) {
String str = (String)iter2.next();
item.add(str);
}
result.add(item);
}
return result;
}
//处理原始事物数据
public HashMap<Set<String>, Integer> process_rawdata(ArrayList<Set<String>> raw_input, int min_sub){
HashMap<Set<String>, Integer> first_input = new HashMap<Set<String>, Integer>(); //存储处理后结果
//处理原始输入事物数据,统计每个单独事物的次数
for(int i = 0; i < raw_input.size(); i++) {
Set<String> item = raw_input.get(i);
Iterator iter = item.iterator();
while(iter.hasNext()) {
String str = (String)iter.next();
Set<String> single_item = new HashSet<String>();
single_item.add(str);
if(first_input.containsKey(single_item)) {
int count = first_input.get(single_item);
first_input.put(single_item, count+1);
}else
first_input.put(single_item, 1);
}
}
//移除单独事物出现次数少于min_sub的事物
for (Iterator<Map.Entry<Set<String>, Integer>> iter = first_input.entrySet().iterator(); iter.hasNext();){
Map.Entry<Set<String>, Integer> entry = iter.next();
Object key = entry.getKey();
int val = (int)entry.getValue();
if(val < min_sub){
iter.remove();
}
}
return first_input;
}
//计数函数,记录每个候选项集在事物数据中出现的次数
public int count_item(Set<String> item, ArrayList<Set<String>> raw_input) {
int count = 0;
Set<String> item2 = new HashSet<>(item);
for(int i = 0; i < raw_input.size(); i++){
Set<String> item_set = new HashSet<String>(raw_input.get(i));
item_set.retainAll(item2);
if(item_set.size() == item2.size())
count++;
}
return count;
}
//算法主函数
public List<HashMap<Set<String>, Integer>> apriori_main(ArrayList<Set<String>> raw_input, int min_sub){
int last_index = 1;
List<HashMap<Set<String>, Integer>> results = new ArrayList<HashMap<Set<String>, Integer>>(); //存储最终结果
HashMap<Set<String>, Integer> first_input = process_rawdata(raw_input, min_sub); //获取第一个频繁项集
ArrayList<Set<String>> candidates = apriori_gen(first_input, last_index); //获取第二个候选项集
while(!(candidates.size() == 0)) { //循环终止条件,无法选出下一个候选集合为止
HashMap<Set<String>, Integer> result = new HashMap<Set<String>, Integer>();
for(int i = 0; i < candidates.size(); i++) {
int count = count_item(candidates.get(i), raw_input); //计算每个候选项集在原始事物数据中出现次数
if(count >= min_sub)
result.put(candidates.get(i), count); //将满足结果的加入结果集中
}
if(result.size() > 0)
results.add(result);
last_index++; //索引加1
candidates = apriori_gen(result, last_index); //计算下一个候选项集合
}
return results;
}
public static void main(String args[]) throws IOException {
ArrayList<Set<String>> raw_data = new ArrayList<Set<String>>(); //存储原始数据
File file = new File(".\\data\\apriori.txt"); //获取外部原始事物数据
BufferedReader reader = new BufferedReader(new FileReader(file));
String string = "";
while((string = reader.readLine())!=null){
Set<String> item = new HashSet<String>();
String[] items = string.split(",");
for(int i = 0; i < items.length; i++)
item.add(items[i]);
raw_data.add(item);
}
Apriori apriori = new Apriori();
List<HashMap<Set<String>, Integer>> result = apriori.apriori_main(raw_data, 2); //定义min_sub为2
System.out.println(result.get(result.size()-1)); //输出最后结果
}
}