一些细节
-
数据库为什么使用List< Set< String > >结构?
因为在从候选k-项集到频繁k-项集的时候要扫描数据库,计算k-项集出现的次数,即计算支持度计数(需要多次扫描数据库)。如果是集合,就可以在O(k)时间判断出k-项集是否出现在某个人的购物清单中,但是如果使用列表,就需要O(kn),n是列表的长度。
代码://计算支持度 public double support(List<String> kSet, List<Set<String>> dataBase){ int count = 0; for (Set<String> set: dataBase ) { if (set.containsAll(kSet)) count++; } return (double) count / dataBase.size(); }
-
为什么只查看k-项集的k个k-1项子集是否是频繁集?
网上的博客帖子都说的是要k-项集的所有非空真子集都要是频繁集才行,为什么我这里只检测它的k个k-1项子集呢。这是因为它的其他的子集都包含在这k个k-1项子集中了。所以只要k个k-1项子集都是频繁的,那么它的其他的子集也必然是频繁的。 -
为什么剪枝的时候先检测k-项集的子集是否频繁集进行筛选,后扫描数据库计算支持度筛选?
这里强调的是顺序的先后。因为数据库很大,扫描一遍数据库是很费时间的,而检测子集则相对来说耗费的时间要小很多。因此先检测子集,缩小后面扫描数据库的时间是可以更好地提升性能的。
实现
package com.ftq.demo.highlevel;
import java.io.*;
import java.util.*;
public class MyAprioriDemo {
private double min_sup;//支持度阈值
private double min_con;//置信度阈值
//构造函数
MyAprioriDemo(double min_sup, double min_con){
this.min_sup = min_sup;
this.min_con = min_con;
}
public void setMin_sup(double sup){
this.min_sup = sup;
}
public void setMin_con(double con){
this.min_con = con;
}
//获取数据
public List<Set<String>> getData(String path){
List<Set<String>> dataSet = new ArrayList<>();
try {
File fin = new File(path);
FileInputStream finS = new FileInputStream(fin);
InputStreamReader isr = new InputStreamReader(finS);
BufferedReader reader = new BufferedReader(isr);
String line = reader.readLine();
while (line != null){
// System.out.println(line);
String[] lin = line.split("\\s");
// System.out.println(Arrays.toString(lin));
dataSet.add(toSet(lin));
line = reader.readLine();
}
reader.close();
isr.close();
finS.close();
}catch (IOException e){
System.out.println(e);
}
return dataSet;
}
//显示数据
public void showDataSet(List<Set<String>> dataSet){
for (Set<String> set: dataSet
) {
for (String ele: set
) {
System.out.print(ele+'\t');
}
System.out.println();
}
}
//显示频繁集
public void showFreqSet(Map<List<String>, Double> freq){
for (Map.Entry<List<String>, Double> entry: freq.entrySet()
) {
System.out.println(entry.getKey().toString() + ": " + entry.getValue());
}
}
//将array转变为一个set,其中array第一个元素是行号,最后一个元素是发散,都不需要
private Set<String> toSet(String[] array){
int len = array.length;
Set<String> set = new HashSet<>();
for (int i = 1; i < len - 1; i++){
set.add("P" + i + "_" + array[i]);
}
return set;
}
//生成频繁一项集
public Map<List<String>, Double> generateOneItem(List<Set<String>> dataBase){
Map<List<String>, Double> frequen = new HashMap<>();
for (Set<String> line: dataBase
) {
for (String item: line
) {
//统计每一项出现的次数
List<String> list = generaList(item);
frequen.put(list, frequen.getOrDefault(list, 0.0) + 1.0);
}
}
int size = dataBase.size();
for (List<String> key: frequen.keySet()
) {
//计算支持度
double sup = frequen.get(key) / size;
if (sup > min_sup){
frequen.put(key, sup);
}else {//如果小于阈值,直接删除
frequen.remove(key);
}
}
return frequen;
}
//生成一个单元素list
private List<String> generaList(String element){
List<String> result = new ArrayList<>();
result.add(element);
return result;
}
//判断两个项集能否进行连接
public boolean ableJoin(List<String> set1, List<String> set2){
int len1 = set1.size();
int len2 = set2.size();
if (len1 != len2) return false;
//能进行连接的条件是前k-1项必须相同,最后一项不同
for (int i = 0; i < len1 - 1; i++) {
// 前面的项必须相同
if (!set1.get(i).equals(set2.get(i))) return false;
}
//最后一项不同
return !set1.get(len1-1).equals(set2.get(len2-1));
}
//两个能连接的项集进行连接
public List<String> join(List<String> set1, List<String> set2){
int len = set1.size();
List<String> result = new ArrayList<>();
for (String ele: set1
) {
result.add(ele);
}
if (set1.get(len-1).compareTo(set2.get(len - 1)) < 0){//如果set1的最后一项小,就将set2最后一项添在最后
result.add(set2.get(len - 1));
}else {//否则插入set1最后一项前面
result.add(len-1, set2.get(len - 1));
}
return result;
}
//判断set的k个k-1项子集是否都是频繁集,
public boolean isRetain(List<String> set, Map<List<String>, Double> frequenSet){
int k = set.size();
//逐个检查k-1项子集是否存在于频繁集中
for (int i = k-1; i > -1; i--) {
List<String> sub = new ArrayList<>();
//生成缺i的k-1项子集
for (int j = 0; j < k; j++) {
if(j != i){
sub.add(set.get(j));
}
}
if (!frequenSet.containsKey(sub)) return false;
sub.clear();
}
return true;
}
//计算支持度
public double support(List<String> kSet, List<Set<String>> dataBase){
int count = 0;
for (Set<String> set: dataBase
) {
if (set.containsAll(kSet)) count++;
}
return (double) count / dataBase.size();
}
//生成List的[i:j)子列表
private List<String> generateList(List<String> list, int i, int j){
List<String> subList = new ArrayList<>();
for (int k = i; k < j; k++) {
subList.add(list.get(k));
}
return subList;
}
//生成一个空列表
private List<List<String>> generateList(){
return new ArrayList<>();
}
//生成关联规则
public Map<List<List<String>>, Double> generateRule(Map<List<String>, Double> freqSet){
Map<List<List<String>>, Double> rule = new HashMap<>();
for (Map.Entry<List<String>, Double> entry: freqSet.entrySet()
) {
List<String> key = entry.getKey();
double value = entry.getValue();
int len = key.size();
for (int i = 1; i < len; i++) {
//生成规则前项
List<String> preRule = generateList(key, 0, i);
//前项的概率
double p = freqSet.get(preRule);
//计算规则的置信度
double v = value / p;
if (v > min_con){
List<String> proRule = generateList(key, i, len);
List<List<String>> k = new ArrayList<>();
k.add(preRule);
k.add(proRule);
rule.put(k, v);
}
}
}
return rule;
}
public static void main(String[] args) {
// new一个对象
MyAprioriDemo apri = new MyAprioriDemo(0.4, 0.2);
// 读取数据库数据到内存
List<Set<String>> dataBase = apri.getData("test_1000.dat");
// apri.showDataSet(data);
// 生成一项频繁集
Map<List<String>, Double> frequenSet = apri.generateOneItem(dataBase);
// apri.showFreqSet(frequenSet);
// k项集
List<List<String>> kItem = new ArrayList<>();
for (Map.Entry<List<String>, Double> ele: frequenSet.entrySet()
) {
kItem.add(ele.getKey());
}
Map<List<String>, Double> kFreqSet = new HashMap<>();
//k项集必须不为空,并且,至少k个k-1项集才可以保证能够生成一个k项候选集
//循环由k-1项集生成k项集
while (kItem.size() > 0 && kItem.get(0).size() < kItem.size()){
for (int i = 0; i < kItem.size(); i++) {
List<String> set1 = kItem.get(i);
// System.out.println(set1.toString());
for (int j = i+1; j < kItem.size(); j++) {
List<String> set2 = kItem.get(j);
// System.out.println(set2.toString());
//握手法遍历所有的k-1项集对儿
//是否能连接
if (apri.ableJoin(set1, set2)){
// System.out.println("true");
List<String> jon = apri.join(set1, set2);
//所有的k-1项集是否都是频繁的,如果有不频繁的子集,就不保留
if (apri.isRetain(jon, frequenSet)){
// 生成的k项候选集是否存在数据库中
//为什么要先查频繁集后查数据库呢,这是因为频繁集教小而数据库很大,先查
// 频繁集就可以缩小搜索的规模
//计算k项集的支持度
Double sup = apri.support(jon, dataBase);
// 大于阈值,就是k项频繁集,保存
if (sup > apri.min_sup){
kFreqSet.put(jon, sup);
}
}
}
}
}
kItem.clear();
//将k项频繁集加入频繁集中
for (Map.Entry<List<String>, Double> ele: kFreqSet.entrySet()
) {
kItem.add(ele.getKey());
frequenSet.put(ele.getKey(), ele.getValue());
}
kFreqSet.clear();
}
// apri.showFreqSet(frequenSet);
Map<List<List<String>>, Double> rule = apri.generateRule(frequenSet);
for (Map.Entry<List<List<String>>, Double> ent: rule.entrySet()
) {
System.out.println(ent.getKey().toString());
System.out.println(ent.getValue());
}
}
}