介绍
Apriori算法是一个经典的数据挖掘算法,Apriori的单词的意思是"先验的",说明这个算法是具有先验性质的,就是说要通过上一次的结果推导出下一次的结果,这个如何体现将会在下面的分析中会慢慢的体现出来。Apriori算法的用处是挖掘频繁项集的,频繁项集粗俗的理解就是找出经常出现的组合,然后根据这些组合最终推出我们的关联规则。
Apriori算法原理
Apriori算法是一种逐层搜索的迭代式算法,其中k项集用于挖掘(k+1)项集,这是依靠他的先验性质的:
频繁项集的所有非空子集一定是也是频繁的。
通过这个性质可以对候选集进行剪枝。用k项集如何生成(k+1)项集呢,这个是算法里面最难也是最核心的部分。
通过2个步骤
1、连接步,将频繁项自己与自己进行连接运算。
2、剪枝步,去除候选集项中的不符合要求的候选项,不符合要求指的是这个候选项的子集并非都是频繁项,要遵守上文提到的先验性质。
3、通过1,2步骤还不够,在后面还要根据支持度计数筛选掉不满足最小支持度数的候选集。
交易ID | 商品ID列表 |
T100 | I1,I2,I5 |
T200 | I2,I4 |
T300 | I2,I3 |
T400 | I1,I2,I4 |
T500 | I1,I3 |
T600 | I2,I3 |
T700 | I1,I3 |
T800 | I1,I2,I3,I5 |
T900 | I1,I2,I3 |
算法的代码实现如下:
package com.gdut.mahao;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
public class Apriori {
private final static int SUPPORT = 2; // 支持度阈值
private final static double CONFIDENCE = 0.7; // 置信度阈值
private final static String ITEM_SPLIT = ","; // 项之间的分隔符
private final static String CON = "-->"; // 项之间的分隔符
private final static List<String> transList = new ArrayList<String>(); // 所有交易
static {// 初始化交易记录,在apriori算法中,应保证项集中的项是有序的
transList.add("1,2,5,");
transList.add("2,4,");
transList.add("2,3,");
transList.add("1,2,4,");
transList.add("1,3,");
transList.add("2,3,");
transList.add("1,3,");
transList.add("1,2,3,5,");
transList.add("1,2,3,");
}
public Map<String, Integer> getFC() {
Map<String, Integer> frequentCollectionMap = new HashMap<String, Integer>();// 所有的频繁集
frequentCollectionMap.putAll(getItem1FC());
Map<String, Integer> itemkFcMap = new HashMap<String, Integer>();
itemkFcMap.putAll(getItem1FC());
while (itemkFcMap != null && itemkFcMap.size() != 0) {
Map<String, Integer> candidateCollection = getCandidateCollection(itemkFcMap);
Set<String> ccKeySet = candidateCollection.keySet();
// 对候选集项进行累加计数
for (String trans : transList) {
for (String candidate : ccKeySet) {
boolean flag = true;// 用来判断交易中是否出现该候选项,如果出现,计数加1
String[] candidateItems = candidate.split(ITEM_SPLIT);
for (String candidateItem : candidateItems) {
if (trans.indexOf(candidateItem + ITEM_SPLIT) == -1) {
flag = false;
break;
}
}
if (flag) {
Integer count = candidateCollection.get(candidate);
candidateCollection.put(candidate, count + 1);
}
}
}
// 从候选集中找到符合支持度的频繁集项
itemkFcMap.clear();
for (String candidate : ccKeySet) {
Integer count = candidateCollection.get(candidate);
if (count >= SUPPORT) {
itemkFcMap.put(candidate, count);
}
}
// 合并所有频繁集
frequentCollectionMap.putAll(itemkFcMap);
}
return frequentCollectionMap;
}
private Map<String, Integer> getCandidateCollection(
Map<String, Integer> itemkFcMap) {
Map<String, Integer> candidateCollection = new HashMap<String, Integer>();
Set<String> itemkSet1 = itemkFcMap.keySet();
Set<String> itemkSet2 = itemkFcMap.keySet();
for (String itemk1 : itemkSet1) {
for (String itemk2 : itemkSet2) {
// 进行连接
String[] tmp1 = itemk1.split(ITEM_SPLIT);
String[] tmp2 = itemk2.split(ITEM_SPLIT);
String c = "";
if (tmp1.length == 1) {//itemkFcMap存放的是候选1项集集合时
if (tmp1[0].compareTo(tmp2[0]) < 0) {
c = tmp1[0] + ITEM_SPLIT + tmp2[0] + ITEM_SPLIT;
}
} else {
boolean flag = true;//是否可以进行连接
for (int i = 0; i < tmp1.length - 1; i++) {
if (!tmp1[i].equals(tmp2[i])) {
flag = false;
break;
}
}
if (flag && (tmp1[tmp1.length - 1].compareTo(tmp2[tmp2.length - 1]) < 0)) {
c = itemk1 + tmp2[tmp2.length - 1] + ITEM_SPLIT;
}
}
// 进行剪枝
boolean hasInfrequentSubSet = false;// 是否有非频繁子项集,默认无
if (!c.equals("")) {
String[] tmpC = c.split(ITEM_SPLIT);
//忽略的索引号ignoreIndex
for (int ignoreIndex = 0; ignoreIndex < tmpC.length; ignoreIndex++) {
String subC = "";
for (int j = 0; j < tmpC.length; j++) {
if (ignoreIndex != j) {
subC += tmpC[j] + ITEM_SPLIT;
}
}
if (itemkFcMap.get(subC) == null) {
hasInfrequentSubSet = true;
break;
}
}
} else {
hasInfrequentSubSet = true;
}
if (!hasInfrequentSubSet) {
//把满足条件的候选项集添加到candidateCollection 集合中
candidateCollection.put(c, 0);
}
}
}
return candidateCollection;
}
//得到频繁1项集
private Map<String, Integer> getItem1FC() {
Map<String, Integer> sItem1FcMap = new HashMap<String, Integer>();
Map<String, Integer> rItem1FcMap = new HashMap<String, Integer>();// 频繁1项集
for (String trans : transList) {
String[] items = trans.split(ITEM_SPLIT);
for (String item : items) {
Integer count = sItem1FcMap.get(item + ITEM_SPLIT);
if (count == null) {
sItem1FcMap.put(item + ITEM_SPLIT, 1);
} else {
sItem1FcMap.put(item + ITEM_SPLIT, count + 1);
}
}
}
Set<String> keySet = sItem1FcMap.keySet();
for (String key : keySet) {
Integer count = sItem1FcMap.get(key);
if (count >= SUPPORT) {
rItem1FcMap.put(key, count);
}
}
return rItem1FcMap;
}
//根据频繁项集集合得到关联规则
public Map<String, Double> getRelationRules(
Map<String, Integer> frequentCollectionMap) {
Map<String, Double> relationRules = new HashMap<String, Double>();
Set<String> keySet = frequentCollectionMap.keySet();
for (String key : keySet) {
double countAll = frequentCollectionMap.get(key);
String[] keyItems = key.split(ITEM_SPLIT);
if (keyItems.length > 1) {
List<String> source = new ArrayList<String>();
Collections.addAll(source, keyItems);
List<Set<String>> result = new ArrayList<Set<String>>();
buildSubSet(source, result);// 获得source的所有非空子集
for (Set<String> itemList : result) {
if (itemList.size() < source.size()) {// 只处理真子集
List<String> otherList = new ArrayList<String>();//记录一个子集的补
for (String sourceItem : source) {
if (!itemList.contains(sourceItem)) {
otherList.add(sourceItem);
}
}
String reasonStr = "";// 规则的前置
String resultStr = "";// 规则的结果
for (String item : itemList) {
reasonStr += item + ITEM_SPLIT;
}
for (String item : otherList) {
resultStr = resultStr + item + ITEM_SPLIT;
}
double countReason = frequentCollectionMap
.get(reasonStr);
double itemConfidence = countAll / countReason;// 计算置信度
//if (itemConfidence >= CONFIDENCE) {
String rule = reasonStr + CON + resultStr;
relationRules.put(rule, itemConfidence);
//}
}
}
}
}
return relationRules;
}
private void buildSubSet(List<String> sourceSet, List<Set<String>> result) {
int n = sourceSet.size();
//n个元素有2^n-1个非空子集
int num = (int) Math.pow(2, n);
for (int i = 1; i < num; i++) {
String binary = Integer.toBinaryString(i);
int size = binary.length();
for (int k = 0; k < n-size; k++) {//将二进制表示字符串右对齐,左边补0
binary = "0"+binary;
}
Set<String> set = new TreeSet<String>();
for (int index = 0; index < sourceSet.size(); index++) {
if(binary.charAt(index) == '1'){
set.add(sourceSet.get(index));
}
}
result.add(set);
}
}
public static void main(String[] args) {
Apriori apriori = new Apriori();
Map<String, Integer> frequentCollectionMap = apriori.getFC();
System.out.println("----------------频繁集" + "----------------");
Set<String> fcKeySet = frequentCollectionMap.keySet();
for (String fcKey : fcKeySet) {
System.out.println(fcKey + " : "
+ frequentCollectionMap.get(fcKey));
}
Map<String, Double> relationRulesMap = apriori
.getRelationRules(frequentCollectionMap);
System.out.println("----------------关联规则" + "----------------");
Set<String> rrKeySet = relationRulesMap.keySet();
for (String rrKey : rrKeySet) {
System.out.println(rrKey + " : " + relationRulesMap.get(rrKey));
}
}
}
运行结果如下:
----------------频繁集----------------
3, : 6
4, : 2
5, : 2
1,5, : 2
1,2, : 4
2, : 7
1,3, : 4
1, : 6
2,5, : 2
1,2,5, : 2
2,3, : 4
1,2,3, : 2
2,4, : 2
----------------关联规则----------------
2,3,-->1, : 0.5
2,-->1,3, : 0.2857142857142857
3,-->2, : 0.6666666666666666
5,-->1,2, : 1.0
2,5,-->1, : 1.0
4,-->2, : 1.0
2,-->1,5, : 0.2857142857142857
5,-->2, : 1.0
2,-->5, : 0.2857142857142857
1,2,-->3, : 0.5
1,-->2,3, : 0.3333333333333333
5,-->1, : 1.0
1,3,-->2, : 0.5
1,2,-->5, : 0.5
1,-->3, : 0.6666666666666666
1,-->2,5, : 0.3333333333333333
1,-->2, : 0.6666666666666666
2,-->4, : 0.2857142857142857
1,-->5, : 0.3333333333333333
2,-->3, : 0.5714285714285714
3,-->1, : 0.6666666666666666
3,-->1,2, : 0.3333333333333333
1,5,-->2, : 1.0
2,-->1, : 0.5714285714285714