这个一个apriori算法的演示版本,所有的代码都在一个类。仅供研究算法参考
package test;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Vector;
//用set写的apriori算法
public class AprioriSetBasedDemo {
class Transaction {
/*
* 购物记录,用set保存多个货物名
*/
private HashSet<String> pnSet = new HashSet<String>();
public Transaction() {
pnSet.clear();
}
public Transaction(String[] names) {
pnSet.clear();
for (String s : names) {
pnSet.add(s);
}
}
public HashSet<String> getPnSet() {
return pnSet;
}
public void addPname(String s) {
pnSet.add(s);
}
public boolean containSubSet(HashSet<String> subSet) {
return pnSet.containsAll(subSet);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Iterator<String> iter = pnSet.iterator();
while (iter.hasNext()) {
sb.append(iter.next() + ",");
}
return "Transaction = [" + sb.toString() + "]";
}
}
class TransactionDB {
// 记录所有的Transaction
private Vector<Transaction> vt = new Vector<Transaction>();
public TransactionDB() {
vt.clear();
}
public int getSize() {
return vt.size();
}
public void addTransaction(Transaction t) {
vt.addElement(t);
}
public Transaction getTransaction(int idx) {
return vt.elementAt(idx);
}
}
public class AssoRule implements Comparable<AssoRule> {
private String ruleContent;
private double confidence;
public void setRuleContent(String ruleContent) {
this.ruleContent = ruleContent;
}
public void setConfidence(double confidence) {
this.confidence = confidence;
}
public AssoRule(String ruleContent, double confidence) {
this.ruleContent = ruleContent;
this.confidence = confidence;
}
@Override
public int compareTo(AssoRule o) {
if (o.confidence > this.confidence) {
return 1;
} else if (o.confidence == this.confidence) {
return 0;
} else {
return -1;
}
}
@Override
public String toString() {
return ruleContent + ", confidence=" + confidence * 100 + "%";
}
}
public static String getStringFromSet(HashSet<String> set) {
StringBuilder sb = new StringBuilder();
Iterator<String> iter = set.iterator();
while (iter.hasNext()) {
sb.append(iter.next() + ", ");
}
if (sb.length() > 2) {
sb.delete(sb.length() - 2, sb.length() - 1);
}
return sb.toString();
}
// 计算具有最小支持度的一项频繁集 >= minSupport
public static HashMap<String, Integer> buildMinSupportFrequenceSet(
TransactionDB tdb, int minSupport) {
HashMap<String, Integer> minSupportMap = new HashMap<String, Integer>();
for (int i = 0; i < tdb.getSize(); i++) {
Transaction t = tdb.getTransaction(i);
Iterator<String> it = t.getPnSet().iterator();
while (it.hasNext()) {
String key = it.next();
if (minSupportMap.containsKey(key)) {
minSupportMap.put(key, minSupportMap.get(key) + 1);
} else {
minSupportMap.put(key, new Integer(1));
}
}
}
Iterator<String> iter = minSupportMap.keySet().iterator();
Vector<String> toBeRemoved = new Vector<String>();
while (iter.hasNext()) {
String key = iter.next();
if (minSupportMap.get(key) < minSupport) {
toBeRemoved.add(key);
}
}
for (int i = 0; i < toBeRemoved.size(); i++) {
minSupportMap.remove(toBeRemoved.get(i));
}
return minSupportMap;
}
public void buildRules(TransactionDB tdb,
HashMap<HashSet<String>, Integer> kItemFS, Vector<AssoRule> var,
double ruleMinSupportPer) {
// 如果kItemFS的成员数量不超过1不需要计算
if (kItemFS.size() <= 1) {
return;
}
// k+1项频项集
HashMap<HashSet<String>, Integer> kNextItemFS = new HashMap<HashSet<String>, Integer>();
// 获得第k项频项集
@SuppressWarnings("unchecked")
HashSet<String>[] kItemSets = new HashSet[kItemFS.size()];
kItemFS.keySet().toArray(kItemSets);
/*
* 根据k项频项集,用两重循环获得k+1项频项集 然后计算有多少个tranction包含这个k+1项频项集
* 然后支持比超过ruleMinSupportPer,就可以生成规则,放入规则向量
* 然后,将k+1项频项集及其支持度放入kNextItemFS,进入下一轮计算
*/
for (int i = 0; i < kItemSets.length - 1; i++) {
HashSet<String> set_i = kItemSets[i];
for (int j = i + 1; j < kItemSets.length; j++) {
HashSet<String> set_j = kItemSets[j];
// k+1 item set
HashSet<String> kNextSet = new HashSet<String>();
kNextSet.addAll(set_i);
kNextSet.addAll(set_j);
if (kNextSet.size() <= set_i.size()
|| kNextSet.size() <= set_j.size()) {
continue;
}
// 计算k+1 item set在所有transaction出现了几次
int count = 0;
for (int k = 0; k < tdb.getSize(); k++) {
if (tdb.getTransaction(k).containSubSet(kNextSet)) {
count++;
}
}
if (count <= 0) {
continue;
}
Integer n_i = kItemFS.get(set_i);
double per = 1.0 * count / n_i.intValue();
if (per >= ruleMinSupportPer) {
kNextItemFS.put(kNextSet, new Integer(count));
HashSet<String> tmp = new HashSet<String>();
tmp.addAll(kNextSet);
tmp.removeAll(set_i);
String s1 = "{" + getStringFromSet(set_i) + "}" + "(" + n_i
+ ")" + "==>" + getStringFromSet(tmp).toString()
+ "(" + count + ")";
var.addElement(new AssoRule(s1, per));
}
}
}
// 进入下一轮计算
buildRules(tdb, kNextItemFS, var, ruleMinSupportPer);
}
public void test() {
// Transaction数据集
TransactionDB tdb = new TransactionDB();
// 添加Transaction交易记录
tdb.addTransaction(new Transaction(new String[] { "a", "b", "c", "d" }));
tdb.addTransaction(new Transaction(new String[] { "a", "b" }));
tdb.addTransaction(new Transaction(new String[] { "b", "c" }));
tdb.addTransaction(new Transaction(new String[] { "b", "c", "d", "e" }));
// 规则最小支持度
double minRuleConfidence = 0.5;
Vector<AssoRule> vr = computeAssociationRules(tdb, minRuleConfidence);
// 输出规则
int i = 0;
for (AssoRule ar : vr) {
System.out.println("rule[" + (i++) + "]: " + ar);
}
}
public Vector<AssoRule> computeAssociationRules(TransactionDB tdb,
double ruleMinSupportPer) {
// 输出关联规则
Vector<AssoRule> var = new Vector<AssoRule>();
// 计算最小支持度频项
HashMap<String, Integer> minSupportMap = buildMinSupportFrequenceSet(
tdb, 2);
// 计算一项频项集
HashMap<HashSet<String>, Integer> oneItemFS = new HashMap<HashSet<String>, Integer>();
for (String key : minSupportMap.keySet()) {
HashSet<String> oneItemSet = new HashSet<String>();
oneItemSet.add(key);
oneItemFS.put(oneItemSet, minSupportMap.get(key));
}
// 根据一项频项集合,递归计算规则
buildRules(tdb, oneItemFS, var, ruleMinSupportPer);
// 将规则按照可信度排序
Collections.sort(var);
return var;
}
public static void main(String[] args) {
AprioriSetBasedDemo asbd = new AprioriSetBasedDemo();
asbd.test();
}
}
运行结果如下:
rule[0]: {d }(2)==>b (2), confidence=100.0%
rule[1]: {d }(2)==>c (2), confidence=100.0%
rule[2]: {d, a }(1)==>c (1), confidence=100.0%
rule[3]: {d, a }(1)==>b (1), confidence=100.0%
rule[4]: {d, a }(1)==>b (1), confidence=100.0%
rule[5]: {d, c }(2)==>b (2), confidence=100.0%
rule[6]: {d, b, a }(1)==>c (1), confidence=100.0%
rule[7]: {d, b, a }(1)==>c (1), confidence=100.0%
rule[8]: {d, c, a }(1)==>b (1), confidence=100.0%
rule[9]: {b }(4)==>c (3), confidence=75.0%
rule[10]: {b, c }(3)==>d (2), confidence=66.66666666666666%
rule[11]: {b, c }(3)==>d (2), confidence=66.66666666666666%
rule[12]: {d }(2)==>a (1), confidence=50.0%
rule[13]: {b }(4)==>a (2), confidence=50.0%
rule[14]: {d, c }(2)==>b, a (1), confidence=50.0%
rule[15]: {d, b }(2)==>a (1), confidence=50.0%