简单说明
学院开了一门课《数据挖掘与机器学习》,要求我们计算机1、2两个班的全部同学选修这门课,包括课程实验。教材采用王振武、徐慧编著的《数据挖掘算法原理与实现》。教材里面提供的代码是C++代码,而由于本人更习惯使用Java语言编程,为了深入理解算法原理和过程,完成实验任务,于是用Java语言实现了Apriori关联规则挖掘算法。
Apriori算法
Apriori算法的基本思想是通过对数据库的多次扫描来计算项集的支持度,发现所有的频繁项集从而生成关联规则。
通俗地说,就是从大量事务数据中找出出现次数不低于最小支持度阈值的数据组合(即频繁项集),再由这些频繁项集生成满足最小置信度的强关联规则(本文代码只实现了频繁项集的挖掘)。
产生频繁项集的过程包括连接和剪枝两步。
连接步:
若两个有序(k-1)-项集的前k-2项相同而最后一项不同,则可将它们连接成一个k-项集。例如,两个有序3-项集L1 = (A, B, C)与L2 = (A, B, D)的前2项相同,可连接产生4-项集C1 = (A, B, C, D)。
剪枝步:
频繁k-项集的任何子集也必然是频繁项集。据此,若连接步产生的某个候选k-项集存在非频繁的(k-1)-子集,则可直接将其删除;其余候选项集再通过扫描数据库计算支持度,去除不满足最小支持度的项集。
代码如下:
//Item.java
import java.util.ArrayList;
import java.util.LinkedHashSet;
/**
* 项集
*/
@SuppressWarnings("hiding")
public class Item<String> extends ArrayList<String> {
private static final long serialVersionUID = 1L;
/**
* 判断本项集与next项集是否可连接
*
* @param next
* @return
*/
public boolean linkable(Item<String> next) {
if (this.size() != next.size())
return false;
for (int i = 0; i < this.size() - 1; i++) {
if (!get(i).equals(next.get(i)))
return false;
}
return true;
}
/**
* 对项集去重
*/
public void unique() {
String s = get(0);
for (int i = 1; i < size(); i++) {
String t = get(i);
while (t.equals(s)) {
remove(t);
if (i < size())
t = get(i);
else {
break;
}
}
s = t;
}
}
}
//Apriori.java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Set;
/**
 * Apriori frequent-itemset miner.
 *
 * <p>Intended call order: add every transaction with {@link #addItem}, then call
 * {@link #findOneElementItems()} and {@link #obtainFrequentOneElementSet()}, then
 * call {@link #obtainFrequentSet(int)} for k = 2, 3, ... until it returns
 * {@code null}. Frequent k-itemsets are stored at index {@code k - 1} of
 * {@link #getRankFrequentSets()}.
 *
 * <p>Every itemset handled here is kept in ascending natural order. The join step
 * relies on itemsets sharing a sorted prefix, and {@link Item} equality is
 * order-sensitive.
 */
public class Apriori {
    private final HashMap<String, Integer> oneElementSet; // occurrence count of every single item
    private final ArrayList<Item<String>> sourceItems; // the transactions, each sorted on insertion
    private final ArrayList<HashMap<Item<String>, Integer>> rankFrequentSets; // index k-1: frequent k-itemsets -> support
    private final int minValue; // minimum support count (inclusive)

    /**
     * Creates an empty miner.
     *
     * @param size expected number of transactions; only used to presize storage
     * @param minValue minimum support count an itemset needs to be frequent
     */
    Apriori(int size, int minValue) {
        oneElementSet = new HashMap<>();
        sourceItems = new ArrayList<>(Math.max(size, 0));
        rankFrequentSets = new ArrayList<>();
        this.minValue = minValue;
    }

    /**
     * Adds a transaction.
     *
     * <p>The given itemset is sorted in place and stored by reference, not copied.
     *
     * @param item the transaction's items
     */
    public void addItem(Item<String> item) {
        item.sort(Comparator.naturalOrder());
        sourceItems.add(item);
    }

    /**
     * Returns the frequent itemsets found so far.
     *
     * @return the internal list; index {@code k - 1} holds the frequent k-itemsets
     *     mapped to their support counts
     */
    public ArrayList<HashMap<Item<String>, Integer>> getRankFrequentSets() {
        return rankFrequentSets;
    }

    /**
     * Counts every occurrence of every item across all stored transactions.
     *
     * <p>The counts are rebuilt from scratch on each call, so repeated calls do not
     * accumulate.
     *
     * @return the internal item-to-count map
     */
    public HashMap<String, Integer> findOneElementItems() {
        oneElementSet.clear();
        for (Item<String> transaction : sourceItems) {
            for (String element : transaction) {
                oneElementSet.merge(element, 1, Integer::sum);
            }
        }
        return oneElementSet;
    }

    /**
     * Builds the frequent 1-itemsets from the counts of
     * {@link #findOneElementItems()} and stores them as level 1 (index 0).
     *
     * @return the frequent 1-itemsets mapped to their counts
     */
    public HashMap<Item<String>, Integer> obtainFrequentOneElementSet() {
        HashMap<Item<String>, Integer> map = new HashMap<>();
        for (String key : oneElementSet.keySet()) {
            int value = oneElementSet.get(key);
            if (value >= minValue) {
                Item<String> item = new Item<>();
                item.add(key);
                map.put(item, value);
            }
        }
        storeLevel(0, map);
        return map;
    }

    /**
     * Builds the frequent k-itemsets.
     *
     * <p>Candidates come from {@link #link(int)}. Their support is then counted
     * against every transaction, and those below {@code minValue} are dropped.
     *
     * @param k itemset size, at least 2; level {@code k - 1} must already exist
     * @return the frequent k-itemsets mapped to their support, or {@code null} if
     *     there are none (nothing is stored in that case)
     */
    public HashMap<Item<String>, Integer> obtainFrequentSet(int k) {
        Item<Item<String>> candidates = link(k);
        HashMap<Item<String>, Integer> freSet = new HashMap<>();
        for (Item<String> candidate : candidates) {
            int count = countSupport(candidate);
            if (count >= minValue) {
                freSet.put(candidate, count);
            }
        }
        if (freSet.isEmpty()) {
            return null;
        }
        storeLevel(k - 1, freSet);
        return freSet;
    }

    /**
     * Join and prune step: builds the candidate k-itemsets from the frequent
     * (k-1)-itemsets.
     *
     * <p>Two (k-1)-itemsets sharing their first k-2 items are joined. Each
     * candidate is sorted. Candidates with a (k-1)-subset that is not frequent are
     * discarded.
     *
     * @param k itemset size, at least 2
     * @return the candidate k-itemsets, each sorted ascending
     */
    public Item<Item<String>> link(int k) {
        Item<Item<String>> candidates = new Item<>();
        HashMap<Item<String>, Integer> previous = rankFrequentSets.get(k - 2);
        // Copy the keys once so each unordered pair can be visited by index.
        ArrayList<Item<String>> keys = new ArrayList<>(previous.keySet());
        for (int i = 0; i < keys.size(); i++) {
            Item<String> first = keys.get(i);
            for (int j = i + 1; j < keys.size(); j++) {
                Item<String> second = keys.get(j);
                if (!first.linkable(second)) {
                    continue;
                }
                Item<String> candidate = new Item<>();
                candidate.addAll(first);
                candidate.add(second.get(k - 2));
                // HashMap order is arbitrary, so second's last item may sort before
                // first's. Keep candidates sorted so later prefix joins and subset
                // lookups (order-sensitive equality) stay correct.
                candidate.sort(Comparator.naturalOrder());
                if (!hasInfrequentSubset(candidate, previous)) {
                    candidates.add(candidate);
                }
            }
        }
        return candidates;
    }

    /** Counts the transactions that contain every item of {@code itemSet}. */
    private int countSupport(Item<String> itemSet) {
        int count = 0;
        for (Item<String> source : sourceItems) {
            if (source.containsAll(itemSet)) {
                count++;
            }
        }
        return count;
    }

    /**
     * Apriori prune check: returns true if some (k-1)-subset of the sorted
     * {@code candidate} is not a key of {@code previous}.
     *
     * <p>Such a candidate cannot be frequent, because an itemset is contained in at
     * most as many transactions as each of its subsets.
     */
    private static boolean hasInfrequentSubset(Item<String> candidate,
            HashMap<Item<String>, Integer> previous) {
        for (int skip = 0; skip < candidate.size(); skip++) {
            Item<String> subset = new Item<>();
            for (int n = 0; n < candidate.size(); n++) {
                if (n != skip) {
                    subset.add(candidate.get(n));
                }
            }
            if (!previous.containsKey(subset)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Stores {@code level} at {@code index}.
     *
     * <p>If that level was computed before, it is replaced rather than shifting the
     * later levels.
     */
    private void storeLevel(int index, HashMap<Item<String>, Integer> level) {
        if (index < rankFrequentSets.size()) {
            rankFrequentSets.set(index, level);
        } else {
            // Throws IndexOutOfBoundsException if a lower level was never stored.
            rankFrequentSets.add(index, level);
        }
    }
}
//Main.java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Scanner;
/**
 * Console driver for the Apriori miner.
 *
 * <p>Reads the transactions from standard input, runs the algorithm, and prints the
 * item counts and every level of frequent itemsets.
 */
public class Main {
    public static void main(String[] args) {
        try (Scanner scanner = new Scanner(System.in)) {
            System.out.print("事务数:");
            int size = scanner.nextInt();
            System.out.print("最小阈值:");
            int minValue = scanner.nextInt();
            Apriori apriori = new Apriori(size, minValue);
            scanner.nextLine(); // discard the rest of the line holding the numbers
            for (int i = 0; i < size; i++) {
                System.out.print("输入第" + (i + 1) + "项:");
                apriori.addItem(parseTransaction(scanner.nextLine()));
            }

            HashMap<String, Integer> oneElementSet = apriori.findOneElementItems();
            for (String key : oneElementSet.keySet()) {
                System.out.println(key + ":" + oneElementSet.get(key));
            }

            apriori.obtainFrequentOneElementSet();
            int k = 2;
            while (apriori.obtainFrequentSet(k) != null) {
                k++;
            }
            // obtainFrequentSet(k) found nothing, so levels 1 .. k-1 exist.
            int levels = k - 1;
            ArrayList<HashMap<Item<String>, Integer>> rankSets = apriori.getRankFrequentSets();
            for (int i = 0; i < levels; i++) {
                System.out.println("第 " + (i + 1) + " 级频繁项集:");
                printItemSets(rankSets.get(i));
            }
            System.out.println("最终频繁项集:");
            printItemSets(rankSets.get(levels - 1));
        }
    }

    /**
     * Splits one input line on whitespace into a deduplicated itemset.
     *
     * <p>An empty line yields an empty itemset.
     */
    private static Item<String> parseTransaction(String line) {
        Item<String> item = new Item<>();
        try (Scanner tokens = new Scanner(line)) {
            while (tokens.hasNext()) {
                item.add(tokens.next());
            }
        }
        if (!item.isEmpty()) {
            // Sort first so duplicates are adjacent. Deduplication is then correct
            // even if unique() only collapses adjacent repeats.
            item.sort(null);
            item.unique();
        }
        return item;
    }

    /** Prints each itemset of {@code map} as "{ a b }" + tab + its count, one per line. */
    private static void printItemSets(HashMap<Item<String>, Integer> map) {
        for (Item<String> itemSet : map.keySet()) {
            System.out.print("{ ");
            for (String s : itemSet) {
                System.out.print(s + " ");
            }
            System.out.print("}\t");
            System.out.println(map.get(itemSet));
        }
    }
}