FP-Tree算法的实现

在关联规则挖掘领域最经典的算法法是Apriori,其致命的缺点是需要多次扫描事务数据库。于是人们提出了各种裁剪(prune)数据集的方法以减少I/O开支,韩嘉炜老师的FP-Tree算法就是其中非常高效的一种。

支持度和置信度

严格地说Apriori和FP-Tree都是寻找频繁项集的算法,频繁项集就是所谓的“支持度”比较高的项集,下面解释一下支持度和置信度的概念。

设事务数据库为:

复制代码
A  E  F  G

A  F  G

A  B  E  F  G

E  F  G
复制代码

则{A,F,G}的支持度数为3,支持度为3/4。

{F,G}的支持度数为4,支持度为4/4。

{A}的支持度数为3,支持度为3/4。

{F,G}=>{A}的置信度为:{A,F,G}的支持度数 除以 {F,G}的支持度数,即3/4

{A}=>{F,G}的置信度为:{A,F,G}的支持度数 除以 {A}的支持度数,即3/3

强关联规则挖掘是在满足一定支持度的情况下寻找置信度达到阈值的所有模式。

FP-Tree算法

我们举个例子来详细讲解FP-Tree算法的完整实现。

事务数据库如下,一行表示一条购物记录:

复制代码
牛奶,鸡蛋,面包,薯片

鸡蛋,爆米花,薯片,啤酒

鸡蛋,面包,薯片

牛奶,鸡蛋,面包,爆米花,薯片,啤酒

牛奶,面包,啤酒

鸡蛋,面包,啤酒

牛奶,面包,薯片

牛奶,鸡蛋,面包,黄油,薯片

牛奶,鸡蛋,黄油,薯片
复制代码

我们的目的是要找出哪些商品总是相伴出现的,比如人们买薯片的时候通常也会买鸡蛋,则[薯片,鸡蛋]就是一条频繁模式(frequent pattern)。

FP-Tree算法第一步:扫描事务数据库,每项商品按频数递减排序,并删除频数小于最小支持度MinSup的商品。(第一次扫描数据库)

薯片:7鸡蛋:7面包:7牛奶:6啤酒:4                       (这里我们令MinSup=3)

以上结果就是频繁1项集,记为F1。

第二步:对于每一条购买记录,按照F1中的顺序重新排序。(第二次也是最后一次扫描数据库)

复制代码
薯片,鸡蛋,面包,牛奶

薯片,鸡蛋,啤酒

薯片,鸡蛋,面包

薯片,鸡蛋,面包,牛奶,啤酒

面包,牛奶,啤酒

鸡蛋,面包,啤酒

薯片,面包,牛奶

薯片,鸡蛋,面包,牛奶

薯片,鸡蛋,牛奶
复制代码

第三步:把第二步得到的各条记录插入到FP-Tree中。刚开始时后缀模式为空。

插入每一条(薯片,鸡蛋,面包,牛奶)之后

插入第二条记录(薯片,鸡蛋,啤酒)

插入第三条记录(面包,牛奶,啤酒)

估计你也知道怎么插了,最终生成的FP-Tree是:

上图中左边的那一叫做表头项,树中相同名称的节点要链接起来,链表的第一个元素就是表头项里的元素。

如果FP-Tree为空(只含一个虚的root节点),则FP-Growth函数返回。

此时输出表头项的每一项+postModel,支持度为表头项中对应项的计数。

第四步:从FP-Tree中找出频繁项。

遍历表头项中的每一项(我们拿“牛奶:6”为例),对于各项都执行以下(1)到(5)的操作:

(1)从FP-Tree中找到所有的“牛奶”节点,向上遍历它的祖先节点,得到4条路径:

复制代码
薯片:7,鸡蛋:6,牛奶:1

薯片:7,鸡蛋:6,面包:4,牛奶:3

薯片:7,面包:1,牛奶:1

面包:1,牛奶:1
复制代码

对于每一条路径上的节点,其count都设置为牛奶的count

复制代码
薯片:1,鸡蛋:1,牛奶:1

薯片:3,鸡蛋:3,面包:3,牛奶:3

薯片:1,面包:1,牛奶:1

面包:1,牛奶:1
复制代码

因为每一项末尾都是牛奶,可以把牛奶去掉,得到条件模式基(Conditional Pattern Base,CPB),此时的后缀模式是:(牛奶)。

复制代码
薯片:1,鸡蛋:1

薯片:3,鸡蛋:3,面包:3

薯片:1,面包:1

面包:1
复制代码

(2)我们把上面的结果当作原始的事务数据库,返回到第3步,递归迭代运行。

没讲清楚,你可以参考这篇博客,直接看核心代码吧:

复制代码
public void FPGrowth(List<List<String>> transRecords,
        List<String> postPattern,Context context) throws IOException, InterruptedException {
    // 构建项头表,同时也是频繁1项集
    ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
    // 构建FP-Tree
    TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
    // 如果FP-Tree为空则返回
    if (treeRoot.getChildren()==null || treeRoot.getChildren().size() == 0)
        return;
    //输出项头表的每一项+postPattern
    if(postPattern!=null){
        for (TreeNode header : HeaderTable) {
            String outStr=header.getName();
            int count=header.getCount();
            for (String ele : postPattern)
                outStr+="\t" + ele;
            context.write(new IntWritable(count), new Text(outStr));
        }
    }
    // 找到项头表的每一项的条件模式基,进入递归迭代
    for (TreeNode header : HeaderTable) {
        // 后缀模式增加一项
        List<String> newPostPattern = new LinkedList<String>();
        newPostPattern.add(header.getName());
        if (postPattern != null)
            newPostPattern.addAll(postPattern);
        // 寻找header的条件模式基CPB,放入newTransRecords中
        List<List<String>> newTransRecords = new LinkedList<List<String>>();
        TreeNode backnode = header.getNextHomonym();
        while (backnode != null) {
            int counter = backnode.getCount();
            List<String> prenodes = new ArrayList<String>();
            TreeNode parent = backnode;
            // 遍历backnode的祖先节点,放到prenodes中
            while ((parent = parent.getParent()).getName() != null) {
                prenodes.add(parent.getName());
            }
            while (counter-- > 0) {
                newTransRecords.add(prenodes);
            }
            backnode = backnode.getNextHomonym();
        }
        // 递归迭代
        FPGrowth(newTransRecords, newPostPattern,context);
    }
}
复制代码

对于FP-Tree已经是单枝的情况,就没有必要再递归调用FPGrowth了,直接输出整条路径上所有节点的各种组合+postModel就可了。例如当FP-Tree为:

我们直接输出:

3  A+postModel

3  B+postModel

3  A+B+postModel

就可以了。

如何按照上面代码里的做法,是先输出:

3  A+postModel

3  B+postModel

然后把B插入到postModel的头部,重新建立一个FP-Tree,这时Tree中只含A,于是输出

3  A+(B+postModel)

两种方法结果是一样的,但毕竟重新建立FP-Tree计算量大些。

Java实现

FP树节点定义

?
package  fptree;
  
import  java.util.ArrayList;
import  java.util.List;
  
public  class  TreeNode implements  Comparable<TreeNode> {
  
     private  String name; // 节点名称
     private  int  count; // 计数
     private  TreeNode parent; // 父节点
     private  List<TreeNode> children; // 子节点
     private  TreeNode nextHomonym; // 下一个同名节点
  
     public  TreeNode() {
  
     }
  
     public  TreeNode(String name) {
         this .name = name;
     }
  
     public  String getName() {
         return  name;
     }
  
     public  void  setName(String name) {
         this .name = name;
     }
  
     public  int  getCount() {
         return  count;
     }
  
     public  void  setCount( int  count) {
         this .count = count;
     }
  
     public  TreeNode getParent() {
         return  parent;
     }
  
     public  void  setParent(TreeNode parent) {
         this .parent = parent;
     }
  
     public  List<TreeNode> getChildren() {
         return  children;
     }
  
     public  void  addChild(TreeNode child) {
         if  ( this .getChildren() == null ) {
             List<TreeNode> list = new  ArrayList<TreeNode>();
             list.add(child);
             this .setChildren(list);
         } else  {
             this .getChildren().add(child);
         }
     }
  
     public  TreeNode findChild(String name) {
         List<TreeNode> children = this .getChildren();
         if  (children != null ) {
             for  (TreeNode child : children) {
                 if  (child.getName().equals(name)) {
                     return  child;
                 }
             }
         }
         return  null ;
     }
  
     public  void  setChildren(List<TreeNode> children) {
         this .children = children;
     }
  
     public  void  printChildrenName() {
         List<TreeNode> children = this .getChildren();
         if  (children != null ) {
             for  (TreeNode child : children) {
                 System.out.print(child.getName() + " " );
             }
         } else  {
             System.out.print( "null" );
         }
     }
  
     public  TreeNode getNextHomonym() {
         return  nextHomonym;
     }
  
     public  void  setNextHomonym(TreeNode nextHomonym) {
         this .nextHomonym = nextHomonym;
     }
  
     public  void  countIncrement( int  n) {
         this .count += n;
     }
  
     @Override
     public  int  compareTo(TreeNode arg0) {
         // TODO Auto-generated method stub
         int  count0 = arg0.getCount();
         // 跟默认的比较大小相反,导致调用Arrays.sort()时是按降序排列
         return  count0 - this .count;
     }
}

挖掘频繁模式

?
package  fptree;
 
import  java.io.BufferedReader;
import  java.io.FileReader;
import  java.io.IOException;
import  java.util.ArrayList;
import  java.util.Collections;
import  java.util.Comparator;
import  java.util.HashMap;
import  java.util.LinkedList;
import  java.util.List;
import  java.util.Map;
import  java.util.Map.Entry;
import  java.util.Set;
 
public  class  FPTree {
 
     private  int  minSuport;
 
     public  int  getMinSuport() {
         return  minSuport;
     }
 
     public  void  setMinSuport( int  minSuport) {
         this .minSuport = minSuport;
     }
 
     // 从若干个文件中读入Transaction Record
     public  List<List<String>> readTransRocords(String... filenames) {
         List<List<String>> transaction = null ;
         if  (filenames.length > 0 ) {
             transaction = new  LinkedList<List<String>>();
             for  (String filename : filenames) {
                 try  {
                     FileReader fr = new  FileReader(filename);
                     BufferedReader br = new  BufferedReader(fr);
                     try  {
                         String line;
                         List<String> record;
                         while  ((line = br.readLine()) != null ) {
                             if (line.trim().length()> 0 ){
                                 String str[] = line.split( "," );
                                 record = new  LinkedList<String>();
                                 for  (String w : str)
                                     record.add(w);
                                 transaction.add(record);
                             }
                         }
                     } finally  {
                         br.close();
                     }
                 } catch  (IOException ex) {
                     System.out.println( "Read transaction records failed."
                             + ex.getMessage());
                     System.exit( 1 );
                 }
             }
         }
         return  transaction;
     }
 
     // FP-Growth算法
     public  void  FPGrowth(List<List<String>> transRecords,
             List<String> postPattern) {
         // 构建项头表,同时也是频繁1项集
         ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
         // 构建FP-Tree
         TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
         // 如果FP-Tree为空则返回
         if  (treeRoot.getChildren()== null  || treeRoot.getChildren().size() == 0 )
             return ;
         //输出项头表的每一项+postPattern
         if (postPattern!= null ){
             for  (TreeNode header : HeaderTable) {
                 System.out.print(header.getCount() + "\t"  + header.getName());
                 for  (String ele : postPattern)
                     System.out.print( "\t"  + ele);
                 System.out.println();
             }
         }
         // 找到项头表的每一项的条件模式基,进入递归迭代
         for  (TreeNode header : HeaderTable) {
             // 后缀模式增加一项
             List<String> newPostPattern = new  LinkedList<String>();
             newPostPattern.add(header.getName());
             if  (postPattern != null )
                 newPostPattern.addAll(postPattern);
             // 寻找header的条件模式基CPB,放入newTransRecords中
             List<List<String>> newTransRecords = new  LinkedList<List<String>>();
             TreeNode backnode = header.getNextHomonym();
             while  (backnode != null ) {
                 int  counter = backnode.getCount();
                 List<String> prenodes = new  ArrayList<String>();
                 TreeNode parent = backnode;
                 // 遍历backnode的祖先节点,放到prenodes中
                 while  ((parent = parent.getParent()).getName() != null ) {
                     prenodes.add(parent.getName());
                 }
                 while  (counter-- > 0 ) {
                     newTransRecords.add(prenodes);
                 }
                 backnode = backnode.getNextHomonym();
             }
             // 递归迭代
             FPGrowth(newTransRecords, newPostPattern);
         }
     }
 
     // 构建项头表,同时也是频繁1项集
     public  ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {
         ArrayList<TreeNode> F1 = null ;
         if  (transRecords.size() > 0 ) {
             F1 = new  ArrayList<TreeNode>();
             Map<String, TreeNode> map = new  HashMap<String, TreeNode>();
             // 计算事务数据库中各项的支持度
             for  (List<String> record : transRecords) {
                 for  (String item : record) {
                     if  (!map.keySet().contains(item)) {
                         TreeNode node = new  TreeNode(item);
                         node.setCount( 1 );
                         map.put(item, node);
                     } else  {
                         map.get(item).countIncrement( 1 );
                     }
                 }
             }
             // 把支持度大于(或等于)minSup的项加入到F1中
             Set<String> names = map.keySet();
             for  (String name : names) {
                 TreeNode tnode = map.get(name);
                 if  (tnode.getCount() >= minSuport) {
                     F1.add(tnode);
                 }
             }
             Collections.sort(F1);
             return  F1;
         } else  {
             return  null ;
         }
     }
 
     // 构建FP-Tree
     public  TreeNode buildFPTree(List<List<String>> transRecords,
             ArrayList<TreeNode> F1) {
         TreeNode root = new  TreeNode(); // 创建树的根节点
         for  (List<String> transRecord : transRecords) {
             LinkedList<String> record = sortByF1(transRecord, F1);
             TreeNode subTreeRoot = root;
             TreeNode tmpRoot = null ;
             if  (root.getChildren() != null ) {
                 while  (!record.isEmpty()
                         && (tmpRoot = subTreeRoot.findChild(record.peek())) != null ) {
                     tmpRoot.countIncrement( 1 );
                     subTreeRoot = tmpRoot;
                     record.poll();
                 }
             }
             addNodes(subTreeRoot, record, F1);
         }
         return  root;
     }
 
     // 把交易记录按项的频繁程序降序排列
     public  LinkedList<String> sortByF1(List<String> transRecord,
             ArrayList<TreeNode> F1) {
         Map<String, Integer> map = new  HashMap<String, Integer>();
         for  (String item : transRecord) {
             // 由于F1已经是按降序排列的,
             for  ( int  i = 0 ; i < F1.size(); i++) {
                 TreeNode tnode = F1.get(i);
                 if  (tnode.getName().equals(item)) {
                     map.put(item, i);
                 }
             }
         }
         ArrayList<Entry<String, Integer>> al = new  ArrayList<Entry<String, Integer>>(
                 map.entrySet());
         Collections.sort(al, new  Comparator<Map.Entry<String, Integer>>() {
             @Override
             public  int  compare(Entry<String, Integer> arg0,
                     Entry<String, Integer> arg1) {
                 // 降序排列
                 return  arg0.getValue() - arg1.getValue();
             }
         });
         LinkedList<String> rest = new  LinkedList<String>();
         for  (Entry<String, Integer> entry : al) {
             rest.add(entry.getKey());
         }
         return  rest;
     }
 
     // 把record作为ancestor的后代插入树中
     public  void  addNodes(TreeNode ancestor, LinkedList<String> record,
             ArrayList<TreeNode> F1) {
         if  (record.size() > 0 ) {
             while  (record.size() > 0 ) {
                 String item = record.poll();
                 TreeNode leafnode = new  TreeNode(item);
                 leafnode.setCount( 1 );
                 leafnode.setParent(ancestor);
                 ancestor.addChild(leafnode);
 
                 for  (TreeNode f1 : F1) {
                     if  (f1.getName().equals(item)) {
                         while  (f1.getNextHomonym() != null ) {
                             f1 = f1.getNextHomonym();
                         }
                         f1.setNextHomonym(leafnode);
                         break ;
                     }
                 }
 
                 addNodes(leafnode, record, F1);
             }
         }
     }
 
     public  static  void  main(String[] args) {
         FPTree fptree = new  FPTree();
         fptree.setMinSuport( 3 );
         List<List<String>> transRecords = fptree
                 .readTransRocords( "/home/orisun/test/market" );
         fptree.FPGrowth(transRecords, null );
     }
}

输入文件

复制代码
牛奶,鸡蛋,面包,薯片
鸡蛋,爆米花,薯片,啤酒
鸡蛋,面包,薯片
牛奶,鸡蛋,面包,爆米花,薯片,啤酒
牛奶,面包,啤酒
鸡蛋,面包,啤酒
牛奶,面包,薯片
牛奶,鸡蛋,面包,黄油,薯片
牛奶,鸡蛋,黄油,薯片
复制代码

输出

复制代码
6    薯片    鸡蛋
5    薯片    面包
5    鸡蛋    面包
4    薯片    鸡蛋    面包
5    薯片    牛奶
5    面包    牛奶
4    鸡蛋    牛奶
4    薯片    面包    牛奶
4    薯片    鸡蛋    牛奶
3    面包    鸡蛋    牛奶
3    薯片    面包    鸡蛋    牛奶
3    鸡蛋    啤酒
3    面包    啤酒
复制代码

用Hadoop来实现

在上面的代码我们把整个事务数据库放在一个List<List<String>>里面传给FPGrowth,在实际中这是不可取的,因为内存不可能容下整个事务数据库,我们可能需要从关系关系数据库中一条一条地读入来建立FP-Tree。但无论如何 FP-Tree是肯定需要放在内存中的,但内存如果容不下怎么办?另外FPGrowth仍然是非常耗时的,你想提高速度怎么办?解决办法:分而治之,并行计算。

我们把原始事务数据库分成N部分,在N个节点上并行地进行FPGrowth挖掘,最后把关联规则汇总到一起就可以了。关键问题是怎么“划分”才会不遗露任何一条关联规则呢?参见这篇博客。这里为了达到并行计算的目的,采用了一种“冗余”的划分方法,即各部分的并集大于原来的集合。这种方法最终求出来的关联规则也是有冗余的,比如在节点1上得到一条规则(6:啤酒,尿布),在节点2上得到一条规则(3:尿布,啤酒),显然节点2上的这条规则是冗余的,需要采用后续步骤把冗余的规则去掉。

代码:

Record.java

?
package  fptree;
 
import  java.io.DataInput;
import  java.io.DataOutput;
import  java.io.IOException;
import  java.util.Collections;
import  java.util.LinkedList;
 
import  org.apache.hadoop.io.WritableComparable;
 
public  class  Record implements  WritableComparable<Record>{
     
     LinkedList<String> list;
     
     public  Record(){
         list= new  LinkedList<String>();
     }
     
     public  Record(String[] arr){
         list= new  LinkedList<String>();
         for ( int  i= 0 ;i<arr.length;i++)
             list.add(arr[i]);
     }
     
     @Override
     public  String toString(){
         String str=list.get( 0 );
         for ( int  i= 1 ;i<list.size();i++)
             str+= "\t" +list.get(i);
         return  str;
     }
 
     @Override
     public  void  readFields(DataInput in) throws  IOException {
         list.clear();
         String line=in.readUTF();
         String []arr=line.split( "\\s+" );
         for ( int  i= 0 ;i<arr.length;i++)
             list.add(arr[i]);
     }
 
     @Override
     public  void  write(DataOutput out) throws  IOException {
         out.writeUTF( this .toString());
     }
 
     @Override
     public  int  compareTo(Record obj) {
         Collections.sort(list);
         Collections.sort(obj.list);
         return  this .toString().compareTo(obj.toString());
     }
 
}

DC_FPTree.java

?
package  fptree;
 
import  java.io.BufferedReader;
import  java.io.IOException;
import  java.io.InputStreamReader;
import  java.util.ArrayList;
import  java.util.BitSet;
import  java.util.Collections;
import  java.util.Comparator;
import  java.util.HashMap;
import  java.util.LinkedList;
import  java.util.List;
import  java.util.Map;
import  java.util.Map.Entry;
import  java.util.Set;
 
import  org.apache.hadoop.conf.Configuration;
import  org.apache.hadoop.conf.Configured;
import  org.apache.hadoop.fs.FSDataInputStream;
import  org.apache.hadoop.fs.FileSystem;
import  org.apache.hadoop.fs.Path;
import  org.apache.hadoop.io.IntWritable;
import  org.apache.hadoop.io.LongWritable;
import  org.apache.hadoop.io.Text;
import  org.apache.hadoop.mapreduce.Job;
import  org.apache.hadoop.mapreduce.Mapper;
import  org.apache.hadoop.mapreduce.Reducer;
import  org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import  org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import  org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import  org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import  org.apache.hadoop.util.Tool;
import  org.apache.hadoop.util.ToolRunner;
 
public  class  DC_FPTree extends  Configured implements  Tool {
 
     private  static  final  int  GroupNum = 10 ;
     private  static  final  int  minSuport= 6 ;
 
     public  static  class  GroupMapper extends
             Mapper<LongWritable, Text, IntWritable, Record> {
         List<String> freq = new  LinkedList<String>(); // 频繁1项集
         List<List<String>> freq_group = new  LinkedList<List<String>>(); // 分组后的频繁1项集
 
         @Override
         public  void  setup(Context context) throws  IOException {
             // 从文件读入频繁1项集
             FileSystem fs = FileSystem.get(context.getConfiguration());
             Path freqFile = new  Path( "/user/orisun/input/F1" );
             FSDataInputStream in = fs.open(freqFile);
             InputStreamReader isr = new  InputStreamReader(in);
             BufferedReader br = new  BufferedReader(isr);
             try  {
                 String line;
                 while  ((line = br.readLine()) != null ) {
                     String[] str = line.split( "\\s+" );
                     String word = str[ 0 ];
                     freq.add(word);
                 }
             } finally  {
                 br.close();
             }
             // 对频繁1项集进行分组
             Collections.shuffle(freq); // 打乱顺序
             int  cap = freq.size() / GroupNum; // 每段分为一组
             for  ( int  i = 0 ; i < GroupNum; i++) {
                 List<String> list = new  LinkedList<String>();
                 for  ( int  j = 0 ; j < cap; j++) {
                     list.add(freq.get(i * cap + j));
                 }
                 freq_group.add(list);
             }
             int  remainder = freq.size() % GroupNum;
             int  base = GroupNum * cap;
             for  ( int  i = 0 ; i < remainder; i++) {
                 freq_group.get(i).add(freq.get(base + i));
             }
         }
 
         @Override
         public  void  map(LongWritable key, Text value, Context context)
                 throws  IOException, InterruptedException {
             String[] arr = value.toString().split( "\\s+" );
             Record record = new  Record(arr);
             LinkedList<String> list = record.list;
             BitSet bs= new  BitSet(freq_group.size());
             bs.clear();
             while  (record.list.size() > 0 ) {
                 String item = list.peekLast(); // 取出record的最后一项
                 int  i= 0 ;
                 for  (; i < freq_group.size(); i++) {
                     if (bs.get(i))
                         continue ;
                     if  (freq_group.get(i).contains(item)) {
                         bs.set(i);
                         break ;
                     }
                 }
                 if (i<freq_group.size()){     //找到了
                     context.write( new  IntWritable(i), record); 
                 }
                 record.list.pollLast();
             }
         }
     }
     
     public  static  class  FPReducer extends  Reducer<IntWritable,Record,IntWritable,Text>{
         public  void  reduce(IntWritable key,Iterable<Record> values,Context context) throws  IOException,InterruptedException{
             List<List<String>> trans= new  LinkedList<List<String>>();
             while (values.iterator().hasNext()){
                 Record record=values.iterator().next();
                 LinkedList<String> list= new  LinkedList<String>();
                 for (String ele:record.list)
                     list.add(ele);
                 trans.add(list);
             }
             FPGrowth(trans, null ,context);
         }
         // FP-Growth算法
     public  void  FPGrowth(List<List<String>> transRecords,
             List<String> postPattern,Context context) throws  IOException, InterruptedException {
         // 构建项头表,同时也是频繁1项集
         ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
         // 构建FP-Tree
         TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
         // 如果FP-Tree为空则返回
         if  (treeRoot.getChildren()== null  || treeRoot.getChildren().size() == 0 )
             return ;
         //输出项头表的每一项+postPattern
         if (postPattern!= null ){
             for  (TreeNode header : HeaderTable) {
                 String outStr=header.getName();
                 int  count=header.getCount();
                 for  (String ele : postPattern)
                     outStr+= "\t"  + ele;
                 context.write( new  IntWritable(count), new  Text(outStr));
             }
         }
         // 找到项头表的每一项的条件模式基,进入递归迭代
         for  (TreeNode header : HeaderTable) {
             // 后缀模式增加一项
             List<String> newPostPattern = new  LinkedList<String>();
             newPostPattern.add(header.getName());
             if  (postPattern != null )
                 newPostPattern.addAll(postPattern);
             // 寻找header的条件模式基CPB,放入newTransRecords中
             List<List<String>> newTransRecords = new  LinkedList<List<String>>();
             TreeNode backnode = header.getNextHomonym();
             while  (backnode != null ) {
                 int  counter = backnode.getCount();
                 List<String> prenodes = new  ArrayList<String>();
                 TreeNode parent = backnode;
                 // 遍历backnode的祖先节点,放到prenodes中
                 while  ((parent = parent.getParent()).getName() != null ) {
                     prenodes.add(parent.getName());
                 }
                 while  (counter-- > 0 ) {
                     newTransRecords.add(prenodes);
                 }
                 backnode = backnode.getNextHomonym();
             }
             // 递归迭代
             FPGrowth(newTransRecords, newPostPattern,context);
         }
     }
 
         // 构建项头表,同时也是频繁1项集
         public  ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {
             ArrayList<TreeNode> F1 = null ;
             if  (transRecords.size() > 0 ) {
                 F1 = new  ArrayList<TreeNode>();
                 Map<String, TreeNode> map = new  HashMap<String, TreeNode>();
                 // 计算事务数据库中各项的支持度
                 for  (List<String> record : transRecords) {
                     for  (String item : record) {
                         if  (!map.keySet().contains(item)) {
                             TreeNode node = new  TreeNode(item);
                             node.setCount( 1 );
                             map.put(item, node);
                         } else  {
                             map.get(item).countIncrement( 1 );
                         }
                     }
                 }
                 // 把支持度大于(或等于)minSup的项加入到F1中
                 Set<String> names = map.keySet();
                 for  (String name : names) {
                     TreeNode tnode = map.get(name);
                     if  (tnode.getCount() >= minSuport) {
                         F1.add(tnode);
                     }
                 }
                 Collections.sort(F1);
                 return  F1;
             } else  {
                 return  null ;
             }
         }
 
         // 构建FP-Tree
         public  TreeNode buildFPTree(List<List<String>> transRecords,
                 ArrayList<TreeNode> F1) {
             TreeNode root = new  TreeNode(); // 创建树的根节点
             for  (List<String> transRecord : transRecords) {
                 LinkedList<String> record = sortByF1(transRecord, F1);
                 TreeNode subTreeRoot = root;
                 TreeNode tmpRoot = null ;
                 if  (root.getChildren() != null ) {
                     while  (!record.isEmpty()
                             && (tmpRoot = subTreeRoot.findChild(record.peek())) != null ) {
                         tmpRoot.countIncrement( 1 );
                         subTreeRoot = tmpRoot;
                         record.poll();
                     }
                 }
                 addNodes(subTreeRoot, record, F1);
             }
             return  root;
         }
 
         // 把交易记录按项的频繁程序降序排列
         public  LinkedList<String> sortByF1(List<String> transRecord,
                 ArrayList<TreeNode> F1) {
             Map<String, Integer> map = new  HashMap<String, Integer>();
             for  (String item : transRecord) {
                 // 由于F1已经是按降序排列的,
                 for  ( int  i = 0 ; i < F1.size(); i++) {
                     TreeNode tnode = F1.get(i);
                     if  (tnode.getName().equals(item)) {
                         map.put(item, i);
                     }
                 }
             }
             ArrayList<Entry<String, Integer>> al = new  ArrayList<Entry<String, Integer>>(
                     map.entrySet());
             Collections.sort(al, new  Comparator<Map.Entry<String, Integer>>() {
                 @Override
                 public  int  compare(Entry<String, Integer> arg0,
                         Entry<String, Integer> arg1) {
                     // 降序排列
                     return  arg0.getValue() - arg1.getValue();
                 }
             });
             LinkedList<String> rest = new  LinkedList<String>();
             for  (Entry<String, Integer> entry : al) {
                 rest.add(entry.getKey());
             }
             return  rest;
         }
 
         // 把record作为ancestor的后代插入树中
         public  void  addNodes(TreeNode ancestor, LinkedList<String> record,
                 ArrayList<TreeNode> F1) {
             if  (record.size() > 0 ) {
                 while  (record.size() > 0 ) {
                     String item = record.poll();
                     TreeNode leafnode = new  TreeNode(item);
                     leafnode.setCount( 1 );
                     leafnode.setParent(ancestor);
                     ancestor.addChild(leafnode);
 
                     for  (TreeNode f1 : F1) {
                         if  (f1.getName().equals(item)) {
                             while  (f1.getNextHomonym() != null ) {
                                 f1 = f1.getNextHomonym();
                             }
                             f1.setNextHomonym(leafnode);
                             break ;
                         }
                     }
 
                     addNodes(leafnode, record, F1);
                 }
             }
         }
     }
     
     public  static  class  InverseMapper extends
             Mapper<LongWritable, Text, Record, IntWritable> {
         @Override
         public  void  map(LongWritable key, Text value, Context context)
                 throws  IOException, InterruptedException {
             String []arr=value.toString().split( "\\s+" );
             int  count=Integer.parseInt(arr[ 0 ]);
             Record record= new  Record();
             for ( int  i= 1 ;i<arr.length;i++){
                 record.list.add(arr[i]);
             }
             context.write(record, new  IntWritable(count));
         }
     }
     
     public  static  class  MaxReducer extends  Reducer<Record,IntWritable,IntWritable,Record>{
         public  void  reduce(Record key,Iterable<IntWritable> values,Context context) throws  IOException,InterruptedException{
             int  max=- 1 ;
             for (IntWritable value:values){
                 int  i=value.get();
                 if (i>max)
                     max=i;
             }
             context.write( new  IntWritable(max), key);
         }
     }
 
 
     @Override
     public  int  run(String[] arg0) throws  Exception {
         Configuration conf=getConf();
         conf.set( "mapred.task.timeout" , "6000000" );
         Job job= new  Job(conf);
         job.setJarByClass(DC_FPTree. class );
         FileSystem fs=FileSystem.get(getConf());
         
         FileInputFormat.setInputPaths(job, "/user/orisun/input/data" );
         Path outDir= new  Path( "/user/orisun/output" );
         fs.delete(outDir, true );
         FileOutputFormat.setOutputPath(job, outDir);
         
         job.setMapperClass(GroupMapper. class );
         job.setReducerClass(FPReducer. class );
         
         job.setInputFormatClass(TextInputFormat. class );
         job.setOutputFormatClass(TextOutputFormat. class );
         job.setMapOutputKeyClass(IntWritable. class );
         job.setMapOutputValueClass(Record. class );
         job.setOutputKeyClass(IntWritable. class );
         job.setOutputKeyClass(Text. class );
         
         boolean  success=job.waitForCompletion( true );
         
         job= new  Job(conf);
         job.setJarByClass(DC_FPTree. class );
         
         FileInputFormat.setInputPaths(job, "/user/orisun/output/part-r-*" );
         Path outDir2= new  Path( "/user/orisun/output2" );
         fs.delete(outDir2, true );
         FileOutputFormat.setOutputPath(job, outDir2);
         
         job.setMapperClass(InverseMapper. class );
         job.setReducerClass(MaxReducer. class );
         //job.setNumReduceTasks(0);
         
         job.setInputFormatClass(TextInputFormat. class );
         job.setOutputFormatClass(TextOutputFormat. class );
         job.setMapOutputKeyClass(Record. class );
         job.setMapOutputValueClass(IntWritable. class );
         job.setOutputKeyClass(IntWritable. class );
         job.setOutputKeyClass(Record. class );
         
         success |= job.waitForCompletion( true );
         
         return  success? 0 : 1 ;
     }
 
     public  static  void  main(String[] args) throws  Exception{
         int  res=ToolRunner.run( new  Configuration(), new  DC_FPTree(), args);
         System.exit(res);
     }
}

结束语

在实践中,关联规则挖掘可能并不像人们期望的那么有用。一方面是因为支持度置信度框架会产生过多的规则,并不是每一个规则都是有用的。另一方面大部分的关联规则并不像“啤酒与尿布”这种经典故事这么普遍。关联规则分析是需要技巧的,有时需要用更严格的统计学知识来控制规则的增殖。 

原文来自:博客园(华夏35度)http://www.cnblogs.com/zhangchaoyang 作者:Orisun
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值