关联规则挖掘算法之Apriori算法的原理解析及代码实现

1、Apriori算法分为两个步骤:

    1)寻找所有不低于最小支持度的项集 (频繁项集, 又称大项集);

    2)使用频繁项集生成规则

PS:

频繁项集:支持度大于最小支持度的项集。

核心思想: 先验性质(向下封闭性质),即频繁项集的任意子集都是频繁的。

迭代算法(又称逐层搜索算法):  寻找所有1-频繁项集; 然后所有2-频繁项集, 依此类推。

举一个实例来说:

数据集如下:

TID

Items

T1

1, 3, 4

T2

2, 3, 5

T3

1, 2, 3, 5

T4

2, 5

寻找频繁项集,规定最小支持度minsup=0.5,即出现次数>=2:

(生成候选集分为两步:一是连接操作,二是剪枝操作)

k=1,扫描T -> C1: {1}:2, {2}:3, {3}:3, {4}:1, {5}:3   //统计每个元素出现的次数

                    -> F1: {1}:2, {2}:3, {3}:3, {5}:3   //去掉支持度小于minsup的元素

                    -> C2: {1,2}, {1,3}, {1,5}, {2,3}, {2,5}, {3,5}   //从F1中取元素,须符合先验原则

k=2,扫描T -> C2: {1,2}:1, {1,3}:2, {1,5}:1, {2,3}:2, {2,5}:3, {3,5}:2  

                    -> F2: {1,3}:2, {2,3}:2, {2,5}:3, {3,5}:2

                    -> C3: {2,3,5}   //剪枝,{1,2,3}和{1,3,5} 不满足先验原则

k=3,扫描T -> C3: {2,3,5}:2

                    -> F2: {2,3,5}

所以1-频繁项集:{1}, {2}, {3}, {5}

        2-频繁项集:{1,3}, {2,3}, {2,5}, {3,5}

        3-频繁项集:{2,3,5}

最后可通过频繁项集生成规则。

2、利用上述数据集实现寻找频繁项集:

(可能代码写的比较冗余= =见谅)

根据上述分析,首先要知道每一个项集中不可能有重复值,候选项集通过上一轮的频繁项集连接产生,根据先验原则剪枝获得。则可将每一个项集看做一个对象,元素用list存储,且为了方便将list排序。

AprioriNode类代码如下:

public class AprioriNode {
    private List<Integer> list = new ArrayList<>();
    
    public AprioriNode(){
    }
    
    public AprioriNode(String str){
        String[] strs = str.split(",");
        for(String value:strs){
            list.add(Integer.parseInt(value));
        }
        Collections.sort(list);
    }
    
    public void add(String str){
        list.add(Integer.parseInt(str));
        Collections.sort(list);
    }
    
    //加入一个元素
    public boolean addValue(int value){
        if(list.contains(value)){
            return false;
        }else{
            list.add(value);
            Collections.sort(list);
            return true;
        }
    }
    
    public List<Integer> getList() {
        return list;
    }
    
    //复制候选项集
    public AprioriNode getCopyNode(){
        AprioriNode node = new AprioriNode();
        for(int value:list){
            node.addValue(value);
        }
        return node;
    }
    
    public String toString(){
        StringBuilder sb = new StringBuilder();
        if(!list.isEmpty()){
            int i = 0;
            for(; i<list.size()-1 ; i++){
                sb.append(list.get(i)).append(",");
            }
            sb.append(list.get(i));
        }
        return sb.toString();
    }
    
    public boolean equals(AprioriNode o){
        if(this.getList().equals(o.getList())){
            return true;
        }
        return false;
    }
}

因为频繁项集必须满足先验原则,及其所有子集均为频繁项集。所以需要获得候选项集的所有子集,然后验证其所有子集是否在上一轮的频繁项集中存在。所以需要一个工具类,如下:

public class AprioriUtil {
    
    /**
     * 通过候选集得到其所有子集,及对候选集中每个元素遍历,每次只删除一个,即获得。
     * @param node 候选项集
     */
    public static List<AprioriNode> genSubSet(AprioriNode node){
        List<AprioriNode> nodes = new ArrayList<>();
        for(int i=0; i<node.getList().size(); i++){
            //因为不能改动传进来的node,所以需要对候选集复制一份
            AprioriNode tempNode = node.getCopyNode();
            tempNode.getList().remove(i);
            nodes.add(tempNode);
        }
        return nodes;
    }
    
    /**
     * 查找候选集的子集是否在上一轮存在
     */
    public static boolean isExist(AprioriNode node, List<AprioriNode> nodes){
        boolean flag = false;
        for (int i = 0; i < nodes.size(); i++) {
            if(node.equals(nodes.get(i))){  //一定记得重写equals方法,因为Object的equals方法默认是==,比较指向地址是否一致
                flag = true;
            }
        }
        return flag;
    }
}

最后实现寻找出所有频繁项集,代码如下(因为代码中有注释,就不详细介绍了):

 

/**
 * 寻找频繁项集
 * @author ZD
 */
public class AprioriTest {
    private static Map<String, Integer> map = new HashMap<String, Integer>();
    
    private static class AprioriTestMapper extends Mapper<LongWritable, Text, Text, IntWritable>{

        @Override
        protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, Text, IntWritable>.Context context)
                throws IOException, InterruptedException {
            String[] strs = value.toString().split(",");
            for(String str:strs){
                context.write(new Text(str), new IntWritable(1));
            }
        }
    }
    
    private static class AprioriKMapper extends Mapper<LongWritable, Text, Text, IntWritable>{
        private Set<String> set = new HashSet<>();   //上一轮的频繁项集的所有不重复元素
        private List<AprioriNode> preSet = new ArrayList<>();   //上一轮的频繁项集
        private List<AprioriNode> db = new ArrayList<>(); //原始的数据集
        @Override
        protected void setup(Mapper<LongWritable, Text, Text, IntWritable>.Context context)
                throws IOException, InterruptedException {
            FileSystem fs = FileSystem.get(context.getConfiguration());
            int k = context.getConfiguration().getInt("k", 1);
            BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(new Path("/apriori/output"+k+"/part-r-00000"))));
            String line="";
            while((line=br.readLine())!=null){
                String[] strs = line.split("\t");
                preSet.add(new AprioriNode(strs[0]));  
                String[] words = strs[0].split(",");
                if(words.length==1){
                    set.add(words[0]);
                }else{
                    for(String word:words){
                        set.add(word);
                    }
                }
            }
            br.close();
            
            br = new BufferedReader(new InputStreamReader(fs.open(new Path("/input/apriori/apriori.txt"))));
            line="";
            while((line=br.readLine())!=null){
                db.add(new AprioriNode(line));
            }
            br.close();
        }

        @Override
        protected void map(LongWritable key, Text value, Mapper<LongWritable, Text, Text, IntWritable>.Context context)
                throws IOException, InterruptedException {
            String[] strs = value.toString().split("\t");
            for(String data:set){
                AprioriNode node = new AprioriNode(strs[0]);
                node.addValue(Integer.parseInt(data));
                //获得候选集的所有子集
                List<AprioriNode> list = AprioriUtil.genSubSet(node);
                int childCount=0;
                for(AprioriNode node2:list){
                    //判断子集是否在上一轮的频繁项集中存在
                    if(AprioriUtil.isExist(node2, preSet)){
                        childCount++;
                    }
                    System.out.println("childCount:"+childCount+", size:"+list.size());
                    //所有子集均为频繁项集的项集才是真正的候选集,相当于剪枝
                    if(childCount==list.size()){
                        //获得候选集在数据集中出现的次数
                        int count=0;
                        for(AprioriNode dbNode:db){
                            int tempCount=0;
                            System.out.println("dbNode:"+dbNode.toString()+", node:"+node.toString());
                            String[] temps = node.toString().split(",");
                            if(temps.length>1){
                                for(String temp:temps){
                                    if(dbNode.toString().contains(temp)){
                                        tempCount++;
                                    }
                                }
                            }
                            if(tempCount==temps.length){
                                count++;
                            }
                        }
                        //将候选项集写出
                        context.write(new Text(node.toString()), new IntWritable(count));
                    }
                }
            }
        }
    }
    
    private static class AprioriTestReducer extends Reducer<Text, IntWritable, Text, IntWritable>{
        
        @Override
        protected void reduce(Text value, Iterable<IntWritable> datas,
                Reducer<Text, IntWritable, Text, IntWritable>.Context context) throws IOException, InterruptedException {
            int sum=0;
            for(IntWritable data:datas){
                sum+=data.get();
            }
            if(sum>=2){
                context.write(value, new IntWritable(sum));
            }
        }
    }
    
    private static class AprioriKReducer extends Reducer<Text, IntWritable, Text, IntWritable>{
        
        @Override
        protected void reduce(Text value, Iterable<IntWritable> datas,
                Reducer<Text, IntWritable, Text, IntWritable>.Context context) throws IOException, InterruptedException {
            for(IntWritable data:datas){
                if(!map.containsKey(value.toString())){
                    map.put(value.toString(), data.get());
                }
            }
        }

        @Override
        protected void cleanup(Reducer<Text, IntWritable, Text, IntWritable>.Context context)
                throws IOException, InterruptedException {
            for(String key:map.keySet()){
                if(map.get(key)>=2){
                    context.write(new Text(key), new IntWritable(map.get(key)));
                }
            }
            map.clear();   //每次job后,清除map内容
        }
    }

    public static void main(String[] args) {
        try {
            //处理初始数据集,获得1-频繁项集
            Configuration cfg = HadoopCfg.getConfigration();
            Job job = Job.getInstance(cfg);
            job.setJobName("AprioriTest");
            job.setJarByClass(AprioriTest.class);
            job.setMapperClass(AprioriTestMapper.class);
            job.setMapOutputKeyClass(Text.class);
            job.setMapOutputValueClass(IntWritable.class);
            job.setReducerClass(AprioriTestReducer.class);
            job.setOutputKeyClass(Text.class);
            job.setOutputValueClass(IntWritable.class);
            FileInputFormat.addInputPath(job, new Path("/input/apriori/"));
            FileOutputFormat.setOutputPath(job, new Path("/apriori/output1/"));
            job.waitForCompletion(true);
            
            for (int k=1; k < 3; k++) {
                cfg.setInt("k", k);
                job = Job.getInstance(cfg);
                job.setJobName("AprioriTest");
                job.setJarByClass(AprioriTest.class);
                job.setMapperClass(AprioriKMapper.class);
                job.setMapOutputKeyClass(Text.class);
                job.setMapOutputValueClass(IntWritable.class);
                job.setReducerClass(AprioriKReducer.class);
                job.setOutputKeyClass(Text.class);
                job.setOutputValueClass(IntWritable.class);
                FileInputFormat.addInputPath(job, new Path("/apriori/output"+k+"/"));
                FileOutputFormat.setOutputPath(job, new Path("/apriori/output"+(k+1)+"/"));
                job.waitForCompletion(true);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

以上代码还可以精简,如可以只用一个Reducer(只保留AprioriTestReducer 且内容不改变), 大家可以试试。

最后利用频繁项集生成规则。

写在最后:若有错误,望大家指出纠正,谢谢。送一句话给自己,人不能为自己的不作为而找借口。下次将与大家分享利用Rsync算法实现简单云盘的文件上传。

转载于:https://my.oschina.net/eager/blog/683781

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值