package AprioriTest;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
/**
* 功能:Apriori算法的实现
* 时间:2012.09.09
* 作者:zhouhong
*/
public class AprioriTest01 {
private int minSup;
private static List data;
private static List> dataSet;
public static void main(String[] args) {
long startTime = System.currentTimeMillis();
AprioriTest01 apriori = new AprioriTest01();
//使用书中的测试集
//apriori.setMinSup(2);
data = apriori.buildData();
//设置最小支持度
apriori.setMinSup(2);
//构造数据集
data = apriori.buildData();
//构造频繁1项集
List> f1Set = apriori.findF1Item(data);
apriori.printSet(f1Set, 1);
List> result = f1Set;
int i = 2;
do{
result = apriori.arioriGen(result);
apriori.printSet(result, i);
i++;
}while(result.size() != 0);
long endTime = System.currentTimeMillis();
System.out.println("共用时: " + (endTime - startTime) + "ms");
}
public void setMinSup(int minSup){
this.minSup = minSup;
}
/**
* 构造原始数据集,可以为之提供参数,也可以不提供
* 如果不提供参数,将按程序默认构造的数据集
* 如果提供参数为文件名,则使用文件中的数据集
*/
List buildData(String...fileName){
List data = new ArrayList();
if(fileName.length != 0){
File file = new File(fileName[0]);
try{
BufferedReader reader = new BufferedReader(new FileReader(file));
String line;
while( ( line = reader.readLine()) != null ){
data.add(line);
}
}catch (FileNotFoundException e){
e.printStackTrace();
}catch (IOException e){
e.printStackTrace();
}
}else{
data.add("I1 I2 I5");
data.add("I2 I4");
data.add("I2 I3");
data.add("I1 I2 T4");
data.add("I1 I3");
data.add("I2 I3");
data.add("I1 I3");
data.add("I1 I2 I3 I5");
data.add("I1 I2 I3");
}
dataSet = new ArrayList>();
Set dSet;
for(String d : data){
dSet = new TreeSet();
String[] dArr = d.split(" ");
for(String str : dArr){
dSet.add(str);
}
dataSet.add(dSet);
}
return data;
}
/**
* 找出候选1项集
* @param data
* @return result
*/
List> findF1Item(List data){
List> result = new ArrayList>();
Map dc = new HashMap();
for(String d : data){
String[] items = d.split(" ");
for(String item : items){
if(dc.containsKey(item)) {
dc.put(item, dc.get(item)+1);
}else{
dc.put(item, 1);
}
}
}
Set itemKeys = dc.keySet();
Set tempKeys = new TreeSet();
for(String str : itemKeys){
tempKeys.add(str);
}
for(String item : tempKeys){
if(dc.get(item) >= minSup) {
Set f1Set = new TreeSet();
f1Set.add(item);
result.add(f1Set);
}
}
return result;
}
/**
* 利用arioriGen方法由k-1项集生成k项集
*@param preSet
*@return
*
*/
List> arioriGen(List> preSet) {
List> result = new ArrayList>();
int preSetSize = preSet.size();
for(int i = 0; i < preSetSize - 1; i++){
for(int j = i + 1; j < preSetSize; j++ ){
String[] strA1 = preSet.get(i).toArray(new String[0]);
String[] strA2 = preSet.get(j).toArray(new String[0]);
if(isCanLink(strA1, strA2)) {//判断两个k-1项集是否符合连接成K项集的条件
Set set = new TreeSet();
for(String str : strA1){
set.add(str);//将strA1加入set中连成前K-1项集
}
set.add((String) strA2[strA2.length-1]);//连接成K项集
//判断K项集是否需要剪切掉,如果不需要被cut掉,则加入到k项集的列表中
if(!isNeedCut(preSet, set)) {
result.add(set);
}
}
}
}
return checkSupport(result);//返回的都是频繁K项集
}
/**
* 把set中的项集与数量集比较并进行计算,求出支持度大于要求的项集
* @param set
* @return
*/
List> checkSupport(List > setList){
List> result = new ArrayList>();
boolean flag = true;
int [] counter = new int[setList.size()];
for(int i = 0; i < setList.size(); i++){
for(Set dSets : dataSet) {
if(setList.get(i).size() > dSets.size()){
flag = true;
}else{
for(String str : setList.get(i)){
if(!dSets.contains(str)){
flag = false;
break;
}
}
if(flag) {
counter[i] += 1;
} else{
flag = true;
}
}
}
}
for(int i = 0; i < setList.size(); i++){
if (counter[i] >= minSup) {
result.add(setList.get(i));
}
}
return result;
}
/**
* 判断两个项集能否执行连接操作
* @param s1
* @param s2
* @return
*/
boolean isCanLink(String [] s1, String[] s2){
boolean flag = true;
if(s1.length == s2.length) {
for(int i = 0; i < s1.length - 1; i ++){
if(!s1[i].equals(s2[i])){
flag = false;
break;
}
}
if(s1[s1.length - 1].equals(s2[s2.length - 1])){
flag = false;
}
}else{
flag = true;
}
return flag;
}
/**
* 判断set是否需要被cut
*
* @param setList
* @param set
* @return
*/
boolean isNeedCut(List> setList, Set set) {//setList指频繁K-1项集,set指候选K项集
boolean flag = false;
List> subSets = getSubset(set);//获得K项集的所有k-1项集
for ( Set subSet : subSets) {
//判断当前的k-1项集set是否在频繁k-1项集中出现,如果出现,则不需要cut
//若没有出现,则需要被cut
if( !isContained(setList, subSet)){
flag = true;
break;
}
}
return flag;
}
/**
* 功能:判断k项集的某k-1项集是否包含在频繁k-1项集列表中
*
* @param setList
* @param set
* @return
*/
boolean isContained(List> setList, Set set){
boolean flag = false;
int position = 0;
for( Set s : setList ) {
String [] sArr = s.toArray(new String[0]);
String [] setArr = set.toArray(new String[0]);
for(int i = 0; i < sArr.length; i++) {
if ( sArr[i].equals(setArr[i])){
//如果对应位置的元素相同,则position为当前位置的值
position = i;
} else{
break;
}
}
//如果position等于数组的长度,说明已经找到某个setList中的集合与
//set集合相同了,退出循环,返回包含
//否则,把position置为0进入下一个比较
if ( position == sArr.length - 1) {
flag = true;
break;
} else {
flag = false;
position = 0;
}
}
return flag;
}
/**
* 获得k项集的所有k-1项子集
*
* @param set
* @return
*/
List> getSubset(Set set){
List> result = new ArrayList>();
String [] setArr = set.toArray(new String[0]);
for( int i = 0; i < setArr.length; i++){
Set subSet = new TreeSet();
for(int j = 0; j < setArr.length; j++){
if( i != j){
subSet.add((String) setArr[j]);
}
}
result.add(subSet);
}
return result;
}
/**
* 功能:打印频繁项集
*/
void printSet(List> setList, int i){
System.out.print("频繁" + i + "项集: 共" + setList.size() + "项: {");
for(Set set : setList) {
System.out.print("[");
for(String str : set) {
System.out.print(str + " ");
}
System.out.print("], ");
}
System.out.println("}");
}
}
运行结果:
频繁1项集: 共4项: {[I1 ], [I2 ], [I3 ], [I5 ], } 频繁2项集: 共5项: {[I1 I2 ], [I1 I3 ], [I1 I5 ], [I2 I3 ], [I2 I5 ], } 频繁3项集: 共2项: {[I1 I2 I3 ], [I1 I2 I5 ], } 频繁4项集: 共0项: {}