1. 算法原理
Apriori关联规则算法的目的就是找出所有的频繁项集,所以需要定义一个评估标准找出频繁项集,即最小支持度。 首先从原始数据集中找出出现的所有项,对应数据集确定候选1项集,根据候选一项集每项在原始项集中的出现次数计算每一项的sup值。比较sup值 / 原始数据集数 的值与最小支持度,小于则舍去,计算出频繁一项集,然后对频繁一项集两项之间求补集,并按照一项集中求sup的方法求取候选二项集及频繁二项集。之后递归求取频繁n项集,当频繁项集项数只有一项时递归结束。得到最后的频繁项集。
2. 代码实现
import java.util.ArrayList;
/**
* @Description 项集item
* @Author Clxk
* @Date 2019/4/15 10:57
* @Version 1.0
*/
public class Data {
private ArrayList<String> data = new ArrayList<>();
private int cnt;
public ArrayList<String> getData() {
return data;
}
public void setData(ArrayList<String> data) {
this.data = data;
}
public int getCnt() {
return cnt;
}
public void setCnt(int cnt) {
this.cnt = cnt;
}
@Override
public boolean equals(Object obj) {
Data rhs = (Data) obj;
boolean eq = this.cnt == rhs.cnt;
if(this.cnt == rhs.cnt) {
for(int i = 0; i < data.size(); i++) {
if(!data.get(i).equals(rhs.data.get(i))) {
eq = false;
break;
}
}
}
return eq;
}
}
import java.util.*;
/**
* @Description Apriori
* @Author Clxk
* @Date 2019/4/15 10:43
* @Version 1.0
*/
public class Main {
/**
* 初始数据集最大值
*/
private static final int MAXN = 1000;
/**
* 数据集长度、最小支持度
*/
private static int datacnt = 0;
private static double minsupport = 0;
/**
* 初始数据集
*/
private static ArrayList<String> []data = new ArrayList[500];
/**
* 项集结构
*/
private static ArrayList<Data> items = new ArrayList<>();
public static void main(String[] args) {
/**
* 原始数据集读取
*/
Scanner scanner = new Scanner(System.in);
System.out.println("请输入数据集的大小: ");
datacnt = scanner.nextInt();
System.out.println("请输入最小支持度: ");
minsupport = scanner.nextDouble();
System.out.println("请输入原始数据集: ");
String str;
scanner.nextLine();
for (int i = 0; i < datacnt; i++) {
data[i] = new ArrayList<>();
str = scanner.nextLine();
String[] split = str.split("\\s");
for (int j = 0; j < split.length; j++) {
data[i].add(split[j]);
}
}
/**
* 数据集处理
*/
solve(data);
}
/**
* 数据集处理
* @param data
*/
public static void solve(ArrayList<String>[] data) {
getFrequent(data, 1);
}
/**
* 获取到频繁1项集
* @param data
*/
public static void getFrequentOne(ArrayList<String>[] data) {
/**
* 获取不重复集合
*/
for(ArrayList<String> list : data) {
if(list == null) break;
for(String s: list) {
Data dt = new Data();
List<String> ls = new ArrayList<>();
ls.add(s);
dt.setData((ArrayList<String>) ls);
int is_have = 0;
for(int i = 0; i < items.size(); i++) {
Data d = items.get(i);
if(d.getData().equals(ls)) {
is_have = 1;
break;
}
}
if(is_have == 0) {
items.add(dt);
}
}
}
}
/**
* 输出候选n项集
* @param n
*/
public static void getCandidate(int n) {
System.out.println("候选" + n + "项集为: ");
outList();
}
/**
* 输出频繁n项集
* @param n
*/
public static void getItems(int n) {
for(int i = 0; i < items.size(); i++) {
if((double)items.get(i).getCnt() / datacnt < minsupport) {
items.remove(i);
i--;
}
}
System.out.println("频繁"+ n +"项集为: ");
outList();
}
/**
* 获取频繁n项集
* @param data
* @param n
*/
public static void getFrequent(ArrayList<String>[] data, int n) {
if(n == 1) {
getFrequentOne(data);
} else {
ArrayList<Data> array = new ArrayList<>();
for(int i = 0; i < items.size(); i++) {
Set<String> set = new HashSet<>();
ArrayList<String> data1 = items.get(i).getData();
for(int j = i+1; j < items.size(); j++) {
set.clear();
ArrayList<String> data2 = items.get(j).getData();
for(int u = 0; u < Math.max(data1.size(), data2.size()); u++) {
if(data1.size() > u) set.add(data1.get(u));
if(data2.size() > u) set.add(data2.get(u));
}
if(set == null || set.size() != n) continue;
put2Items(array,set);
}
}
items = (ArrayList<Data>) array.clone();
}
/**
* 获取sup值
*/
addSup(n);
/**
* 输出候选n项集
*/
getCandidate(n);
/**
* 输出频繁n项集
*/
getItems(n);
if(items.size() > 1) {
getFrequent(data, n+1);
}
}
/**
* 获取Sup值
* @param n
*/
public static void addSup(int n) {
for(int i = 0; i < items.size(); i++) {
ArrayList<String> list = items.get(i).getData();
int cnt = 0;
for(int j = 0; j < datacnt; j++) {
int have = 1;
ArrayList<String> cur = data[j];
for(int u = 0; u < list.size(); u++) {
if(!cur.contains(list.get(u))) {
have = 0;
break;
}
}
if(have == 1) cnt++;
}
Data d = new Data();
d.setData(list);
d.setCnt(cnt);
items.set(i, d);
}
}
/**
* 整理候选频繁项集,同项相加
* @param array,set
*/
public static void put2Items(ArrayList<Data> array, Set<String> set) {
Data data = new Data();
for(String s:set) {
data.getData().add(s);
}
int is_have = 0;
for(int i = 0; i < array.size(); i++) {
if(array.get(i) == null) break;
if(array.get(i).equals(data)) {
is_have = 1;
array.set(i, data);
break;
}
}
if(is_have == 0) {
array.add(data);
}
}
/**
* 输出项集
*/
public static void outList() {
for(Data data : items) {
System.out.println(Arrays.toString(data.getData().toArray()) + " " + data.getCnt());
}
}
}