java实现关联规则算法--Apriori

一些细节

  1. 数据库为什么使用List< Set< String > >结构?
    因为在从候选k-项集到频繁k-项集的时候要扫描数据库,计算k-项集出现的次数,即计算支持度计数(需要多次扫描数据库)。如果是集合,就可以在O(k)时间判断出k-项集是否出现在某个人的购物清单中,但是如果使用列表,就需要O(kn),n是列表的长度。
    代码:

    //计算支持度
    public double support(List<String> kSet, List<Set<String>> dataBase){
        int count = 0;
        for (Set<String> set: dataBase
             ) {
            if (set.containsAll(kSet)) count++;
        }
        return (double) count / dataBase.size();
    }
    
  2. 为什么只查看k-项集的k个k-1项子集是否是频繁集?
    网上的博客帖子都说的是要k-项集的所有非空真子集都要是频繁集才行,为什么我这里只检测它的k个k-1项子集呢。这是因为它的其他的子集都包含在这k个k-1项子集中了。所以只要k个k-1项子集都是频繁的,那么它的其他的子集也必然是频繁的。

  3. 为什么剪枝的时候检测k-项集的子集是否频繁集进行筛选,扫描数据库计算支持度筛选?
    这里强调的是顺序的先后。因为数据库很大,扫描一遍数据库是很费时间的,而检测子集则相对来说耗费的时间要小很多。因此先检测子集,缩小后面扫描数据库的时间是可以更好地提升性能的。

实现

package com.ftq.demo.highlevel;

import java.io.*;
import java.util.*;

public class MyAprioriDemo {
    private double min_sup;//支持度阈值
    private double min_con;//置信度阈值

    //构造函数
    MyAprioriDemo(double min_sup, double min_con){
        this.min_sup = min_sup;
        this.min_con = min_con;
    }

    public void setMin_sup(double sup){
        this.min_sup = sup;
    }
    public void setMin_con(double con){
        this.min_con = con;
    }

    //获取数据
    public List<Set<String>> getData(String path){
        List<Set<String>> dataSet = new ArrayList<>();
        try {
            File fin = new File(path);
            FileInputStream finS = new FileInputStream(fin);
            InputStreamReader isr = new InputStreamReader(finS);
            BufferedReader reader = new BufferedReader(isr);
            String line = reader.readLine();
            while (line != null){
//                System.out.println(line);
                String[] lin = line.split("\\s");
//                System.out.println(Arrays.toString(lin));
                dataSet.add(toSet(lin));
                line = reader.readLine();
            }
            reader.close();
            isr.close();
            finS.close();
        }catch (IOException e){
            System.out.println(e);
        }

        return dataSet;
    }

    //显示数据
    public void showDataSet(List<Set<String>> dataSet){
        for (Set<String> set: dataSet
             ) {
            for (String ele: set
                 ) {
                System.out.print(ele+'\t');
            }
            System.out.println();
        }
    }

    //显示频繁集
    public void showFreqSet(Map<List<String>, Double> freq){
        for (Map.Entry<List<String>, Double> entry: freq.entrySet()
             ) {
            System.out.println(entry.getKey().toString() + ": " + entry.getValue());
        }
    }
    //将array转变为一个set,其中array第一个元素是行号,最后一个元素是发散,都不需要
    private Set<String> toSet(String[] array){
        int len = array.length;
        Set<String> set = new HashSet<>();
        for (int i = 1; i < len - 1; i++){
            set.add("P" + i + "_" + array[i]);
        }
        return set;
    }

    //生成频繁一项集
    public Map<List<String>, Double> generateOneItem(List<Set<String>> dataBase){
        Map<List<String>, Double> frequen = new HashMap<>();
        for (Set<String> line: dataBase
             ) {
            for (String item: line
                 ) {
                //统计每一项出现的次数
                List<String> list = generaList(item);
                frequen.put(list, frequen.getOrDefault(list, 0.0) + 1.0);
            }
        }
        int size = dataBase.size();
        for (List<String> key: frequen.keySet()
             ) {
            //计算支持度
            double sup = frequen.get(key) / size;
            if (sup > min_sup){
                frequen.put(key, sup);
            }else {//如果小于阈值,直接删除
                frequen.remove(key);
            }
        }
        return frequen;
    }

    //生成一个单元素list
    private List<String> generaList(String element){
        List<String> result = new ArrayList<>();
        result.add(element);
        return result;
    }


    //判断两个项集能否进行连接
    public boolean ableJoin(List<String> set1, List<String> set2){
        int len1 = set1.size();
        int len2 = set2.size();
        if (len1 != len2) return false;
        //能进行连接的条件是前k-1项必须相同,最后一项不同
        for (int i = 0; i < len1 - 1; i++) {
//            前面的项必须相同
            if (!set1.get(i).equals(set2.get(i))) return false;
        }
        //最后一项不同
        return !set1.get(len1-1).equals(set2.get(len2-1));
    }

    //两个能连接的项集进行连接
    public List<String> join(List<String> set1, List<String> set2){
        int len = set1.size();
        List<String> result = new ArrayList<>();
        for (String ele: set1
             ) {
            result.add(ele);
        }
        if (set1.get(len-1).compareTo(set2.get(len - 1)) < 0){//如果set1的最后一项小,就将set2最后一项添在最后
            result.add(set2.get(len - 1));
        }else {//否则插入set1最后一项前面
            result.add(len-1, set2.get(len - 1));
        }
        return result;
    }

    //判断set的k个k-1项子集是否都是频繁集,
    public boolean isRetain(List<String> set, Map<List<String>, Double> frequenSet){
        int k = set.size();
        //逐个检查k-1项子集是否存在于频繁集中
        for (int i = k-1; i > -1; i--) {
            List<String> sub = new ArrayList<>();
            //生成缺i的k-1项子集
            for (int j = 0; j < k; j++) {
                if(j != i){
                    sub.add(set.get(j));
                }
            }
            if (!frequenSet.containsKey(sub)) return false;
            sub.clear();
        }
        return true;
    }

    //计算支持度
    public double support(List<String> kSet, List<Set<String>> dataBase){
        int count = 0;
        for (Set<String> set: dataBase
             ) {
            if (set.containsAll(kSet)) count++;
        }
        return (double) count / dataBase.size();
    }

    //生成List的[i:j)子列表
    private List<String> generateList(List<String> list, int i, int j){
        List<String> subList = new ArrayList<>();
        for (int k = i; k < j; k++) {
            subList.add(list.get(k));
        }
        return subList;
    }
    //生成一个空列表
    private List<List<String>> generateList(){
        return new ArrayList<>();
    }

    //生成关联规则
    public Map<List<List<String>>, Double> generateRule(Map<List<String>, Double> freqSet){
        Map<List<List<String>>, Double> rule = new HashMap<>();
        for (Map.Entry<List<String>, Double> entry: freqSet.entrySet()
             ) {
            List<String> key = entry.getKey();
            double value = entry.getValue();
            int len = key.size();
            for (int i = 1; i < len; i++) {
                //生成规则前项
                List<String> preRule = generateList(key, 0, i);
                //前项的概率
                double p = freqSet.get(preRule);
                //计算规则的置信度
                double v = value / p;
                if (v > min_con){
                    List<String> proRule = generateList(key, i, len);
        
                    List<List<String>> k = new ArrayList<>();
                    k.add(preRule);
                    k.add(proRule);
                    rule.put(k, v);
                }
            }
        }
        return rule;
    }



    public static void main(String[] args) {
//      new一个对象
        MyAprioriDemo apri = new MyAprioriDemo(0.4, 0.2);
//      读取数据库数据到内存
        List<Set<String>> dataBase = apri.getData("test_1000.dat");
//        apri.showDataSet(data);
//        生成一项频繁集
        Map<List<String>, Double> frequenSet = apri.generateOneItem(dataBase);
//        apri.showFreqSet(frequenSet);
//        k项集
        List<List<String>> kItem = new ArrayList<>();
        for (Map.Entry<List<String>, Double> ele: frequenSet.entrySet()
             ) {
            kItem.add(ele.getKey());
        }
        Map<List<String>, Double> kFreqSet = new HashMap<>();
        //k项集必须不为空,并且,至少k个k-1项集才可以保证能够生成一个k项候选集
        //循环由k-1项集生成k项集
        while (kItem.size() > 0 && kItem.get(0).size() < kItem.size()){
            for (int i = 0; i < kItem.size(); i++) {
                List<String> set1 = kItem.get(i);
//                System.out.println(set1.toString());
                for (int j = i+1; j < kItem.size(); j++) {
                    List<String> set2 = kItem.get(j);
//                    System.out.println(set2.toString());

                    //握手法遍历所有的k-1项集对儿
                    //是否能连接
                    if (apri.ableJoin(set1, set2)){
//                        System.out.println("true");
                        List<String> jon = apri.join(set1, set2);
                        //所有的k-1项集是否都是频繁的,如果有不频繁的子集,就不保留
                        if (apri.isRetain(jon, frequenSet)){
//                            生成的k项候选集是否存在数据库中
                            //为什么要先查频繁集后查数据库呢,这是因为频繁集教小而数据库很大,先查
//                            频繁集就可以缩小搜索的规模
                            //计算k项集的支持度
                            Double sup = apri.support(jon, dataBase);
//                          大于阈值,就是k项频繁集,保存
                            if (sup > apri.min_sup){
                                kFreqSet.put(jon, sup);
                            }
                        }
                    }
                }
            }
            kItem.clear();
            //将k项频繁集加入频繁集中
            for (Map.Entry<List<String>, Double> ele: kFreqSet.entrySet()
                 ) {
                kItem.add(ele.getKey());
                frequenSet.put(ele.getKey(), ele.getValue());
            }
            kFreqSet.clear();
        }
//        apri.showFreqSet(frequenSet);
        Map<List<List<String>>, Double> rule = apri.generateRule(frequenSet);
        for (Map.Entry<List<List<String>>, Double> ent: rule.entrySet()
             ) {
            System.out.println(ent.getKey().toString());
            System.out.println(ent.getValue());
        }
    }
}

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值