这是一门课程的作业。Apriori 算法的原理网上资料很多，可自行查阅；本文只提供代码和测试数据。
代码:
package com.outsider.apriori;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
/**
 * Apriori frequent-itemset mining plus association-rule generation.
 *
 * <p>Every non-blank input line is one transaction whose items are separated by
 * {@link #ITEM_SPLIT}. An itemset is represented as a string: its items sorted in natural
 * {@link String} order and joined with {@link #ITEM_SPLIT} (a 1-itemset is just the item).
 * The support of an itemset is the absolute number of transactions containing all of its
 * items; an itemset is frequent when its support is at least {@link #SUPPORT}. From the
 * frequent itemsets, rules {@code antecedent -> item} with a single-item consequent are
 * produced when their confidence {@code support(itemset) / support(antecedent)} is at least
 * {@link #CONFIDENCE}.
 *
 * @author outsider
 */
public class Apriori {
    /** Minimum support: an itemset must occur in at least this many transactions. */
    private static final int SUPPORT = 4;
    /** Minimum confidence a rule must reach to be reported. */
    private static final double CONFIDENCE = 0.6;
    /** Item separator, used both in the input file and inside itemset keys. */
    private static final String ITEM_SPLIT = ",";
    /** Default input file path (one transaction per line). */
    private static final String FILE_ADDRESS = "./data/weather.nominal.txt";
    /** Arrow printed between a rule's antecedent and its consequent. */
    private static final String CON = "->";

    /**
     * Runs the whole pipeline and prints the number of frequent itemsets per size, every
     * frequent itemset with its support, and every association rule with its confidence.
     *
     * @param args optional; {@code args[0]} overrides the input file path, which otherwise
     *             defaults to {@link #FILE_ADDRESS}
     * @throws IOException if the input file cannot be opened or read
     */
    public static void main(String[] args) throws IOException {
        String path = (args != null && args.length > 0) ? args[0] : FILE_ADDRESS;
        Apriori apriori = new Apriori();

        // 1. Read every non-blank line of the file as one transaction.
        ArrayList<String[]> transactions = apriori.readFile(path);

        // 2. Start from the frequent 1-itemsets and grow one level at a time until a level
        //    produces no frequent itemset.
        Map<String, Integer> frequentSetMap = new HashMap<>();
        Map<String, Integer> levelMap = apriori.findFrequentOneSets(transactions);
        while (!levelMap.isEmpty()) {
            frequentSetMap.putAll(levelMap);
            Map<String, Integer> candidates = apriori.getCandidateSetMap(levelMap);
            levelMap = apriori.getFrequentSetMap(transactions, candidates);
        }

        // Count the frequent itemsets of each size: L(1), L(2), ..., L(k).
        Map<Integer, Integer> setNums = new HashMap<>();
        for (String set : frequentSetMap.keySet()) {
            setNums.merge(set.split(ITEM_SPLIT).length, 1, Integer::sum);
        }
        System.out.println("各项集个数:");
        List<Integer> sizes = new ArrayList<>(setNums.keySet());
        Collections.sort(sizes);
        for (Integer size : sizes) {
            System.out.println("L(" + size + "):" + setNums.get(size));
        }

        Map<String, Double> associationRules = apriori.getAssociationRules(frequentSetMap);

        System.out.println("频繁项集:");
        List<String> list = new ArrayList<>(frequentSetMap.keySet());
        Collections.sort(list);
        for (String string : list) {
            System.out.println(string + " : " + frequentSetMap.get(string));
        }

        System.out.println("关联规则: ");
        // Sorted so the output is deterministic; HashMap iteration order is not.
        List<String> associationRulesList = new ArrayList<>(associationRules.keySet());
        Collections.sort(associationRulesList);
        for (String string : associationRulesList) {
            System.out.println(string + " : " + associationRules.get(string));
        }
    }

    /**
     * Reads the transaction file. Each line is split on {@link #ITEM_SPLIT}; items are
     * trimmed, empty items are dropped, and lines left with no items are skipped.
     *
     * @param fileAdd path of the input file
     * @return one item array per non-blank line, in file order; never {@code null}
     * @throws IOException if the file cannot be opened or read
     */
    private ArrayList<String[]> readFile(String fileAdd) throws IOException {
        ArrayList<String[]> arrayList = new ArrayList<>();
        // try-with-resources closes the reader on every path. I/O errors propagate to the
        // caller instead of being printed and turned into a null return (which used to
        // surface later as a NullPointerException).
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileAdd)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                List<String> items = new ArrayList<>();
                for (String col : line.split(ITEM_SPLIT)) {
                    String item = col.trim();
                    if (!item.isEmpty()) {
                        items.add(item);
                    }
                }
                // A blank line (e.g. a trailing empty line) would otherwise become a
                // transaction containing the bogus empty item "".
                if (!items.isEmpty()) {
                    arrayList.add(items.toArray(new String[0]));
                }
            }
        }
        return arrayList;
    }

    /**
     * Finds the frequent 1-itemsets.
     *
     * @param dataList the transactions
     * @return map from item to support, keeping only items with support >= {@link #SUPPORT}
     */
    private Map<String, Integer> findFrequentOneSets(ArrayList<String[]> dataList) {
        Map<String, Integer> frequentOneSets = new HashMap<>();
        for (String[] sample : dataList) {
            // Deduplicate within the transaction: support counts transactions, not
            // occurrences, consistent with getFrequentSetMap.
            Set<String> distinct = new HashSet<>();
            Collections.addAll(distinct, sample);
            for (String item : distinct) {
                frequentOneSets.merge(item, 1, Integer::sum);
            }
        }
        // Keep only the items that reach the minimum support.
        frequentOneSets.values().removeIf(count -> count < SUPPORT);
        return frequentOneSets;
    }

    /**
     * Builds the candidate (k+1)-itemsets from the frequent k-itemsets (join step), keeping
     * only candidates whose k-subsets are all frequent (prune step).
     *
     * @param inputMap the frequent k-itemsets (all keys have the same size k)
     * @return map from candidate (k+1)-itemset key to an initial count of 0; empty when
     *         {@code inputMap} is empty
     */
    private Map<String, Integer> getCandidateSetMap(Map<String, Integer> inputMap) {
        Map<String, Integer> result = new HashMap<>();
        if (inputMap.isEmpty()) {
            return result;
        }
        // Split every key once up front instead of inside the quadratic pair loop.
        List<String[]> itemsets = new ArrayList<>(inputMap.size());
        for (String key : inputMap.keySet()) {
            itemsets.add(key.split(ITEM_SPLIT));
        }
        int kSetLen = itemsets.get(0).length + 1;
        // Join step: union each unordered pair once; keep unions of exactly k+1 items.
        for (int a = 0; a < itemsets.size(); a++) {
            for (int b = a + 1; b < itemsets.size(); b++) {
                Set<String> union = new HashSet<>();
                Collections.addAll(union, itemsets.get(a));
                Collections.addAll(union, itemsets.get(b));
                if (union.size() != kSetLen) {
                    continue;
                }
                List<String> sorted = new ArrayList<>(union);
                Collections.sort(sorted);
                String candidate = String.join(ITEM_SPLIT, sorted);
                // Prune step: a (k+1)-itemset with an infrequent k-subset cannot be frequent.
                if (!result.containsKey(candidate) && allSubsetsFrequent(sorted, inputMap)) {
                    result.put(candidate, 0);
                }
            }
        }
        return result;
    }

    /**
     * Checks the Apriori property for one candidate.
     *
     * @param sortedItems the candidate's items, sorted in natural order
     * @param frequent    the frequent itemsets one level below the candidate
     * @return whether every subset obtained by dropping exactly one item is a key of
     *         {@code frequent}
     */
    private static boolean allSubsetsFrequent(List<String> sortedItems,
            Map<String, Integer> frequent) {
        for (int skip = 0; skip < sortedItems.size(); skip++) {
            List<String> subset = new ArrayList<>(sortedItems);
            subset.remove(skip); // remove by index; the remaining items stay sorted
            if (!frequent.containsKey(String.join(ITEM_SPLIT, subset))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Counts the support of each candidate and keeps those meeting the minimum support.
     * The given map is updated in place and returned.
     *
     * @param inputList the transactions
     * @param inputMap  candidate itemsets mapped to their starting count (0)
     * @return {@code inputMap}, with counts filled in and infrequent candidates removed
     */
    private Map<String, Integer> getFrequentSetMap(ArrayList<String[]> inputList,
            Map<String, Integer> inputMap) {
        // Split every candidate key once rather than once per transaction.
        Map<String, List<String>> candidateItems = new HashMap<>();
        for (String key : inputMap.keySet()) {
            List<String> items = new ArrayList<>();
            Collections.addAll(items, key.split(ITEM_SPLIT));
            candidateItems.put(key, items);
        }
        for (String[] transaction : inputList) {
            // Hash set gives O(1) membership checks instead of List.contains scans.
            Set<String> present = new HashSet<>();
            Collections.addAll(present, transaction);
            for (Map.Entry<String, List<String>> candidate : candidateItems.entrySet()) {
                if (present.containsAll(candidate.getValue())) {
                    inputMap.merge(candidate.getKey(), 1, Integer::sum);
                }
            }
        }
        inputMap.values().removeIf(count -> count < SUPPORT);
        return inputMap;
    }

    /**
     * Derives association rules from the frequent itemsets. For every frequent itemset
     * of size >= 2, each item in turn is used as the consequent and the remaining items
     * as the antecedent.
     *
     * @param frequentSetMap all frequent itemsets mapped to their support
     * @return map from rule text {@code antecedent + CON + consequent} to its confidence,
     *         containing only rules with confidence >= {@link #CONFIDENCE}
     */
    private Map<String, Double> getAssociationRules(Map<String, Integer> frequentSetMap) {
        Map<String, Double> rules = new HashMap<>();
        for (Map.Entry<String, Integer> entry : frequentSetMap.entrySet()) {
            String[] items = entry.getKey().split(ITEM_SPLIT);
            if (items.length < 2) {
                continue; // a rule needs a non-empty antecedent and a consequent
            }
            for (int i = 0; i < items.length; i++) {
                String res = items[i];
                String[] carr = new String[items.length - 1];
                int count = 0;
                for (int j = 0; j < items.length; j++) {
                    if (j != i) {
                        carr[count++] = items[j];
                    }
                }
                // Items are sorted within the key, so the antecedent key is sorted too.
                // Every subset of a frequent itemset is frequent, so it has an entry.
                String condition = String.join(ITEM_SPLIT, carr);
                int f = frequentSetMap.get(condition);
                double confidence = entry.getValue() * 1.0 / f;
                if (confidence >= CONFIDENCE) {
                    rules.put(condition + CON + res, confidence);
                }
            }
        }
        return rules;
    }
}
测试数据（保存为 ./data/weather.nominal.txt，每行一条事务，项之间以英文逗号分隔）：
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no