Java implementations of the FP-Growth algorithm for single-machine and cluster environments (association rule mining)

1 Brief description of FP-Growth

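     Like Apriori, FP-Growth is an algorithm for mining association rules. Apriori scans the whole transaction database once for every level k of frequent itemsets it generates. On a large database this causes heavy I/O, so Apriori is practical mainly for finding frequent itemsets in small datasets. FP-Growth scans the transaction database only twice. The first pass counts the support of every item to build the header table. The second pass inserts each transaction into the FP-tree, after removing its infrequent items and sorting the rest by descending support. In this implementation, data preprocessing (removing transaction IDs, merging identical transactions, and so on) happens when the data is loaded. All later mining works on the compact FP-tree rather than on the raw data, so FP-Growth stays efficient on large datasets. The main steps of FP-Growth are:

- data preprocessing;
- building the header table, keeping only items that meet the minimum support;
- building the FP-tree;
- traversing the header table and recursively deriving each item's conditional pattern base, which yields all frequent itemsets.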
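The first scan and the reordering step described above can be sketched in Java as follows. This is a minimal, hypothetical illustration: the class `HeaderTableSketch` and its methods are not part of the original code. It assumes an absolute minimum-support count, as in the author's `FPTree(int support)` constructor. Unlike the author's `buildHeaderTable` below, this sketch counts an item at most once per transaction. It also breaks ties in support by item name, but only to make the order deterministic.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of FP-Growth's first scan (not the author's code).
class HeaderTableSketch {

    // First scan: count the transactions containing each item, keep items with
    // support >= minSupport, and sort them by descending support (ties by name).
    static List<String> headerOrder(List<List<String>> transactions, int minSupport) {
        Map<String, Integer> support = new HashMap<>();
        for (List<String> t : transactions) {
            for (String item : new HashSet<>(t)) { // count an item once per transaction
                support.merge(item, 1, Integer::sum);
            }
        }
        List<String> header = new ArrayList<>();
        for (Map.Entry<String, Integer> e : support.entrySet()) {
            if (e.getValue() >= minSupport) {
                header.add(e.getKey());
            }
        }
        header.sort((x, y) -> {
            int bySupport = Integer.compare(support.get(y), support.get(x));
            return bySupport != 0 ? bySupport : x.compareTo(y);
        });
        return header;
    }

    // Before insertion into the FP-tree: drop infrequent items and keep header order.
    static List<String> reorder(List<String> transaction, List<String> header) {
        List<String> sorted = new ArrayList<>();
        for (String item : header) {
            if (transaction.contains(item)) {
                sorted.add(item);
            }
        }
        return sorted;
    }

    public static void main(String[] args) {
        List<List<String>> db = Arrays.asList(
                Arrays.asList("a", "b", "c"),
                Arrays.asList("a", "b"),
                Arrays.asList("a", "d"),
                Arrays.asList("b", "c"));
        List<String> header = headerOrder(db, 2);
        System.out.println(header);                                         // [a, b, c]
        System.out.println(reorder(Arrays.asList("c", "d", "a"), header)); // [a, c]
    }
}
```

With a minimum support of 2, item d (support 1) is dropped. The transaction {c, d, a} therefore enters the tree as the path a → c.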

2 Single-machine Java implementation of FP-Growth (source code)

package org.min.ml;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * FP-tree algorithm: mines the frequent itemsets of a transaction database;
 * an improvement over the Apriori algorithm.
 *
 * @author ShiMin
 * @date   2015/10/17
 */
public class FPTree
{
    private int minSupport; // minimum support (absolute count)

    public FPTree(int support)
    {
        this.minSupport = support;
    }

    /**
     * Load the transaction database.
     * @param file path of the input file; each line is one transaction whose items are separated by single spaces
     */
    public List<List<String>> loadTransaction(String file)
    {
        List<List<String>> transactions = new ArrayList<List<String>>();
        // try-with-resources closes the reader even if reading fails
        try (BufferedReader br = new BufferedReader(new FileReader(file)))
        {
            String line;
            while((line = br.readLine()) != null)
            {
                if(!line.trim().isEmpty()) // skip blank lines
                {
                    transactions.add(Arrays.asList(line.split(" ")));
                }
            }
        } catch (IOException e)
        {
            e.printStackTrace();
        }
        return transactions;
    }

    /**
     * Recursively mine frequent itemsets.
     * @param transactions (conditional) transaction database
     * @param postPattern  suffix pattern accumulated so far; null at the top level
     */
    public void FPGrowth(List<List<String>> transactions, List<String> postPattern)
    {
        // build the header table
        List<TNode> headerTable = buildHeaderTable(transactions);
        // build the FP-tree
        TNode tree = bulidFPTree(headerTable, transactions);
        // stop when the tree is empty
        if (tree.getChildren() == null || tree.getChildren().size() == 0)
        {
            return;
        }
        // output the frequent itemsets: each header item combined with the suffix pattern
        // (at the top level, postPattern == null, the frequent 1-itemsets are not printed)
        if(postPattern != null)
        {
            for (TNode head : headerTable)
            {
                System.out.print(head.getCount() + " " + head.getItemName());
                for (String item : postPattern)
                {
                    System.out.print(" " + item);
                }
                System.out.println();
            }
        }
        // traverse every node of the header table
        for(TNode head : headerTable)
        {
            List<String> newPostPattern = new LinkedList<String>();
            newPostPattern.add(head.getItemName()); // add the current item to the suffix pattern
            // append the suffix pattern accumulated so far
            if (postPattern != null)
            {
                newPostPattern.addAll(postPattern);
            }
            // conditional pattern base of head, used as the new transaction database
            List<List<String>> newTransaction = new LinkedList<List<String>>();
            TNode nextnode = head.getNext();
            // follow the node-link chain of head.getItemName() and collect the prefix path of each node
            while(nextnode != null)
            {
                int count = nextnode.getCount();
                List<String> parentNodes = new ArrayList<String>(); // all ancestors of nextnode (root excluded)
                TNode parent = nextnode.getParent();
                while(parent.getItemName() != null)
                {
                    parentNodes.add(parent.getItemName());
                    parent = parent.getParent();
                }
                // add parentNodes count times; the resulting frequent itemsets are parentNodes -> newPostPattern
                while((count--) > 0)
                {
                    newTransaction.add(parentNodes);
                }
                // next node with the same item name
                nextnode = nextnode.getNext();
            }
            // repeat all of the above on the conditional database (recursion)
            FPGrowth(newTransaction, newPostPattern);
        }
    }

    /**
     * Build the header table: items meeting the minimum support, sorted by descending count.
     * @return the sorted header table
     */
    public List<TNode> buildHeaderTable(List<List<String>> transactions)
    {
        List<TNode> list = new ArrayList<TNode>();
        Map<String, TNode> nodesmap = new HashMap<String, TNode>();
        // create one node per item and count its occurrences
        for(List<String> lines : transactions)
        {
            for(String itemName : lines)
            {
                if(!nodesmap.containsKey(itemName)) // first occurrence: create the node (count = 1)
                {
                    nodesmap.put(itemName, new TNode(itemName));
                }
                else // already seen: increase its count by 1
                {
                    nodesmap.get(itemName).increaseCount(1);
                }
            }
        }
        // keep only the items that meet the minimum support
        for(TNode item : nodesmap.values())
        {
            if(item.getCount() >= minSupport)
            {
                list.add(item);
            }
        }
        // sort by count, from high to low
        Collections.sort(list);
        return list;
    }

    /**
     * Build the FP-tree.
     * @param headertable  header table, sorted by descending count
     * @param transactions transaction database
     * @return the root of the FP-tree
     */
    public TNode bulidFPTree(List<TNode> headertable, List<List<String>> transactions)
    {
        TNode rootNode = new TNode();
        for(List<String> items : transactions)
        {
            LinkedList<String> itemsDesc = sortItemsByDesc(items, headertable);
            // find the parent node under which itemsDesc is added as a subtree:
            // follow existing nodes that match a prefix of itemsDesc, increasing their counts
            TNode subtreeRoot = rootNode;
            if(subtreeRoot.getChildren().size() != 0)
            {
                TNode tempNode = subtreeRoot.findChildren(itemsDesc.peek());
                while(!itemsDesc.isEmpty() && tempNode != null)
                {
                    tempNode.increaseCount(1);
                    subtreeRoot = tempNode;
                    itemsDesc.poll();
                    tempNode = subtreeRoot.findChildren(itemsDesc.peek());
                }
            }
            // add the remaining items of itemsDesc as a subtree of subtreeRoot
            addSubTree(headertable, subtreeRoot, itemsDesc);
        }
        return rootNode;
    }

    /**
     * Add the remaining items as a chain of new nodes below subtreeRoot.
     * @param headertable header table; each new node is appended to its item's node-link chain
     * @param subtreeRoot parent node of the subtree
     * @param itemsDesc   items of the subtree to add, in header-table order
     */
    public void addSubTree(List<TNode> headertable, TNode subtreeRoot, LinkedList<String> itemsDesc)
    {
        if(itemsDesc.size() > 0)
        {
            TNode thisnode = new TNode(itemsDesc.pop()); // create a new node
            subtreeRoot.getChildren().add(thisnode);
            thisnode.setParent(subtreeRoot);
            // append thisnode to the end of the node-link chain of the matching header-table entry
            for(TNode node : headertable)
            {
                if(node.getItemName().equals(thisnode.getItemName()))
                {
                    TNode lastNode = node;
                    while(lastNode.getNext() != null)
                    {
                        lastNode = lastNode.getNext();
                    }
                    lastNode.setNext(thisnode);
                }
            }
            subtreeRoot = thisnode; // the current node becomes the parent
            // recursively add the remaining items
            addSubTree(headertable, subtreeRoot, itemsDesc);
        }
    }

    // keep only the header-table items of a transaction, ordered by count from high to low
    public LinkedList<String> sortItemsByDesc(List<String> items, List<TNode> headertable)
    {
        LinkedList<String> itemsDesc = new LinkedList<String>();
        for(TNode node : headertable)
        {
            if(items.contains(node.getItemName()))
            {
                itemsDesc.add(node.getItemName());
            }
        }
        return itemsDesc;
    }

    public static void main(String[] args)
    {
        FPTree fptree = new FPTree(4);
        // input path: the first command-line argument, or the author's original local file
        String path = args.length > 0 ? args[0] : "C:\\Users\\shimin\\Desktop\\新建文件夹\\wordcounts.txt";
        List<List<String>> transactions = fptree.loadTransaction(path);
        fptree.FPGrowth(transactions, null);
    }

    /**
     * Node of the FP-tree (one node per item occurrence on a path).
     * @author shimin
     */
    public static class TNode implements Comparable<TNode>
    {
        private String itemName; // item name (null for the root)
        private int count; // number of occurrences in the transaction database
        private TNode parent; // parent node
        private List<TNode> children; // child nodes
        private TNode next; // next node with the same item name (node-link)

        public TNode()
        {
            this.children = new ArrayList<TNode>();
        }
        public TNode(String name)
        {
            this.itemName = name;
            this.count = 1;
            this.children = new ArrayList<TNode>();
        }
        public TNode findChildren(String childName)
        {
            for(TNode node : this.getChildren())
            {
                if(node.getItemName().equals(childName))
                {
                    return node;
                }
            }
            return null;
        }
        public TNode getNext()
        {
            return next;
        }
        public TNode getParent()
        {
            return parent;
        }
        public void setNext(TNode next)
        {
            this.next = next;
        }
        public void increaseCount(int num)
        {
            count += num;
        }
        public int getCount()
        {
            return count;
        }
        public String getItemName()
        {
            return itemName;
        }
        public List<TNode> getChildren()
        {
            return children;
        }
        public void setParent(TNode parent)
        {
            this.parent = parent;
        }
        // descending order of count; Integer.compare avoids overflow of a plain subtraction
        @Override
        public int compareTo(TNode o)
        {
            return Integer.compare(o.getCount(), this.getCount());
        }
    }
}
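The core of the recursion in FPGrowth is the conditional pattern base. For each node in an item's node-link chain, you take the path from that node's parent up to the root and repeat it count times. The sketch below is a hypothetical, stripped-down illustration, not the author's code. Its `Node` class stands in for `TNode` and keeps only the fields this step needs (no child lists). The tree is the one built from the transactions {a, b, c}, {a, b}, {a, d}, {b, c} with a minimum support of 2, header order a, b, c.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of conditional-pattern-base extraction (not the author's code).
class PatternBaseSketch {

    // Minimal FP-tree node: item, count, parent link and node-link to the next
    // node with the same item. Children are omitted because only upward walks are needed.
    static final class Node {
        final String item; // null for the root
        final int count;
        final Node parent;
        Node next;

        Node(String item, int count, Node parent) {
            this.item = item;
            this.count = count;
            this.parent = parent;
        }
    }

    // For every node in the chain, collect its prefix path (nearest ancestor first,
    // root excluded) and add it count times, as FPGrowth does with parentNodes.
    static List<List<String>> conditionalPatternBase(Node firstInChain) {
        List<List<String>> base = new ArrayList<>();
        for (Node n = firstInChain; n != null; n = n.next) {
            List<String> path = new ArrayList<>();
            for (Node p = n.parent; p != null && p.item != null; p = p.parent) {
                path.add(p.item);
            }
            for (int i = 0; i < n.count; i++) {
                base.add(path);
            }
        }
        return base;
    }

    public static void main(String[] args) {
        // root -> a:3 -> b:2 -> c:1   and   root -> b:1 -> c:1
        Node root = new Node(null, 0, null);
        Node a = new Node("a", 3, root);
        Node ab = new Node("b", 2, a);
        Node abc = new Node("c", 1, ab);
        Node b = new Node("b", 1, root);
        Node bc = new Node("c", 1, b);
        abc.next = bc; // node-link chain for item c

        System.out.println(conditionalPatternBase(abc)); // [[b, a], [b]]
    }
}
```

In this base, b has support 2 and a has support 1. With a minimum support of 2, the recursion for suffix c therefore yields only {b, c} with support 2, which matches a direct count over the four transactions.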

3 Java implementation of FP-Growth on a Spark cluster (source code)
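Compared with the single-machine version, the Spark version adds a preprocessing step. Each input line is assumed to start with a transaction ID, which the first `map` removes. `constructTransactions` then merges identical transactions into (transaction, count) pairs using `mapToPair` and `reduceByKey`, so that `bulidFPTree` can insert each distinct transaction once, together with its count. The plain-Java sketch below shows what that preprocessing computes on a small in-memory input. Its class `CollapseSketch` is hypothetical and has no Spark dependency. It illustrates only the result, not how Spark distributes the work.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, Spark-free sketch of the Spark preprocessing (not the author's code).
class CollapseSketch {

    // Strip the leading transaction ID, then count identical transactions,
    // like map(...) followed by mapToPair(line -> (line, 1)) and reduceByKey(+).
    static Map<String, Integer> collapse(List<String> rawLines) {
        Map<String, Integer> counts = new LinkedHashMap<>(); // first-seen order, for readable output
        for (String raw : rawLines) {
            // Same rule as the Spark map(): drop everything up to and including the first space.
            String line = raw.substring(raw.indexOf(" ") + 1).trim();
            counts.merge(line, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        List<String> raw = Arrays.asList("T1 a b c", "T2 a b", "T3 a b c", "T4 b c");
        System.out.println(collapse(raw)); // {a b c=2, a b=1, b c=1}
    }
}
```

Merging is done on the exact line text, so "a b" and "b a" remain separate keys. Both still end up on the same FP-tree path, because `bulidFPTree` reorders every transaction by the header table before inserting it.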

<span style=“font-size:14px;”>package org.min.fpgrowth;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;

import scala.Tuple2;
/**
 * @author ShiMin
 * @date   2015/10/19
 * @description FPGrowth algorithm runs on spark in java.
 */
public class FPTree
{
    public static int SUPPORT_DEGREE = 4; // minimum support (absolute count) of the FPGrowth algorithm
    public static String SEPARATOR = " "; // item separator within a transaction line

    public static void main(String[] args)
    {
        Logger.getLogger("org.apache.spark").setLevel(Level.OFF);
        // fall back to the author's HDFS paths when no arguments are given
        if(args.length == 0)
        {
            args = new String[]{"hdfs://master:9000/data/input/wordcounts.txt", "hdfs://master:9000/data/output"};
        }
        if(args.length != 2)
        {
            System.err.println("Usage: <Datapath> <Output>");
            System.exit(1);
        }

        // local[4]: run locally with 4 threads; a master set here in code takes precedence over spark-submit's --master
        SparkConf sparkConf = new SparkConf().setAppName("frequent pattern growth").setMaster("local[4]");
        JavaSparkContext ctx = new JavaSparkContext(sparkConf);

        // load the transactions data
        JavaRDD<String> lines = ctx.textFile(args[0], 1)
        // remove the ID of each transaction (everything up to the first space)
        .map(new Function<String, String>()
        {
            private static final long serialVersionUID = -692074104680411557L;

            public String call(String arg0) throws Exception
            {
                return arg0.substring(arg0.indexOf(" ") + 1).trim();
            }
        });

        JavaPairRDD<String, Integer> transactions = constructTransactions(lines);
        // run FPGrowth algorithm
        FPGrowth(transactions, null, ctx);
        // close sparkContext
        ctx.close();
    }

    /**
     * Merge identical transactions into (transaction, count) pairs.
     */
    public static JavaPairRDD<String, Integer> constructTransactions(JavaRDD<String> lines)
    {
        JavaPairRDD<String, Integer> transactions = lines
                // convert lines to <key,value> (i.e. <line,1>) pairs
                .mapToPair(new PairFunction<String, String, Integer>()
                {
                    private static final long serialVersionUID = 5915574436834522347L;

                    public Tuple2<String, Integer> call(String arg0) throws Exception
                    {
                        return new Tuple2<String, Integer>(arg0, 1);
                    }
                })
                // combine identical transactions by summing their counts
                .reduceByKey(new Function2<Integer, Integer, Integer>()
                {
                    private static final long serialVersionUID = -8075378913994858824L;

                    public Integer call(Integer arg0, Integer arg1) throws Exception
                    {
                        return arg0 + arg1;
                    }
                });
        return transactions;
    }
    /**
     * Recursively mine frequent itemsets.
     * @param transactions (conditional) transaction database as (transaction, count) pairs
     * @param postPattern  suffix pattern accumulated so far; null at the top level
     * @param ctx          Spark context used to parallelize each conditional database
     */
    public static void FPGrowth(JavaPairRDD<String, Integer> transactions, final List<String> postPattern, JavaSparkContext ctx)
    {
        // construct headTable
        JavaRDD<TNode> headTable = bulidHeadTable(transactions);
        List<TNode> headlist = headTable.collect();
        // construct FPTree
        TNode tree = bulidFPTree(headlist, transactions);
        // when the FPTree is empty, stop the recursion
        if(tree.getChildren() == null || tree.getChildren().size() == 0)
        {
            return;
        }
        // output the frequent itemsets (the frequent 1-itemsets at the top level are not printed)
        if(postPattern != null)
        {
            for(TNode head : headlist)
            {
                System.out.print(head.getCount() + " " + head.getItemName());
                for(String item : postPattern)
                {
                    System.out.print(" " + item);
                }
                System.out.println();
            }
            // Unused variant: in cluster mode foreach runs on the executors, so this output
            // would go to the executor logs instead of the driver console.
//          headTable.foreach(new VoidFunction<TNode>()
//          {
//              public void call(TNode head) throws Exception
//              {
//                  System.out.println(head.getCount() + " " + head.getItemName());
//                  for(String item : postPattern)
//                  {
//                      System.out.println(" " + item);
//                  }
//              }
//          });
        }
        // traverse each item of headTable
        for(TNode head : headlist)
        {
            List<String> newPostPattern = new ArrayList<String>();
            newPostPattern.add(head.getItemName());
            if(postPattern != null)
            {
                newPostPattern.addAll(postPattern);
            }
            // create the conditional transactions: the prefix path of every node in head's node-link chain
            List<String> newTransactionsList = new ArrayList<String>();
            TNode nextNode = head.getNext();
            while(nextNode != null)
            {
                // weight of this prefix path = count of this tree node, not the item's total support in the header table
                int count = nextNode.getCount();
                TNode parent = nextNode.getParent();
                String tlines = "";
                while(parent.getItemName() != null)
                {
                    tlines += parent.getItemName() + SEPARATOR;
                    parent = parent.getParent();
                }
                tlines = tlines.trim();
                while((count--) > 0 && !tlines.equals(""))
                {
                    newTransactionsList.add(tlines);
                }
                nextNode = nextNode.getNext();
            }
            JavaPairRDD<String, Integer> newTransactions = constructTransactions(ctx.parallelize(newTransactionsList));
            FPGrowth(newTransactions, newPostPattern, ctx);
        }
    }

    /**
     * Construct the FPTree on the driver: keep only the header-table items of each
     * distinct transaction (in header-table order), collect the results, and insert
     * each one into the tree with its count.
     * @return the root of FPTree
     */
    public static TNode bulidFPTree(List<TNode> headTable, JavaPairRDD<String, Integer> transactions)
    {
        // create the root node of FPTree
        final TNode rootNode = new TNode();

        final List<TNode> headItems = headTable;
        // convert to transactions ordered by count DESC, keeping only items that satisfy the minimum support
        JavaPairRDD<LinkedList<String>, Integer> transactionsDesc = transactions.mapToPair(new PairFunction<Tuple2<String,Integer>, LinkedList<String>, Integer>()
        {
            private static final long serialVersionUID = 4787405828748201473L;

            public Tuple2<LinkedList<String>, Integer> call(Tuple2<String, Integer> t)
                    throws Exception
            {
                LinkedList<String> descItems = new LinkedList<String>();
                List<String> items = Arrays.asList(t._1.split(SEPARATOR));
                for(TNode node : headItems)
                {
                    String headName = node.getItemName();
                    if(items.contains(headName))
                    {
                        descItems.add(headName);
                    }
                }
                return new Tuple2<LinkedList<String>, Integer>(descItems, t._2);
            }
        })
        // drop transactions that contain no frequent item
        .filter(new Function<Tuple2<LinkedList<String>,Integer>, Boolean>()
        {
            private static final long serialVersionUID = -8157084572151575538L;

            public Boolean call(Tuple2<LinkedList<String>, Integer> v1) throws Exception
            {
                return v1._1.size() > 0;
            }
        });
        List<Tuple2<LinkedList<String>, Integer>> lines = transactionsDesc.collect();
        // add each transaction to FPTree
        for(Tuple2<LinkedList<String>, Integer> t : lines)
        {
            LinkedList<String> itemsDesc = t._1; // items to be added to FPTree
            int count = t._2; // how many times itemsDesc occurs, i.e. the amount added to each node on its path
            // find the deepest existing node that matches a prefix of itemsDesc
            TNode subtreeRoot = rootNode;
            if(subtreeRoot.getChildren().size() != 0)
            {
                TNode tempNode = subtreeRoot.findChildren(itemsDesc.peek());
                while(!itemsDesc.isEmpty() && tempNode != null)
                {
                    tempNode.countIncrement(count);
                    subtreeRoot = tempNode;
                    itemsDesc.poll();
                    tempNode = subtreeRoot.findChildren(itemsDesc.peek());
                }
            }
            // add the remaining items of itemsDesc to FPTree as a new branch
            addSubTree(headItems, subtreeRoot, itemsDesc, count);
        }

        // Unused variant kept for reference: foreach runs on the executors against
        // deserialized copies of rootNode, so it would not update the driver's tree.
//      transactionsDesc.foreach(new VoidFunction<Tuple2<LinkedList<String>,Integer>>()
//      {
//          private static final long serialVersionUID = 8054620367976985036L;
//
//          public void call(Tuple2<LinkedList<String>, Integer> t) throws Exception
//          {
//              LinkedList<String> itemsDesc = t._1;//items to be added to FPTree
//              int count = t._2;//how many times itemsDesc to be added to FPTree
//              //find out the root node which add List<String> as subtree
//              TNode subtreeRoot = rootNode;
//              if(subtreeRoot.getChildren().size() != 0)
//              {
//                  TNode tempNode = subtreeRoot.findChildren(itemsDesc.peek());
//                  while(!itemsDesc.isEmpty() && tempNode != null)
//                  {
//                      tempNode.countIncrement(count * 2);
//                      subtreeRoot = tempNode;
//                      itemsDesc.poll();
//                      tempNode = subtreeRoot.findChildren(itemsDesc.peek());
//                  }
//              }
//              //add the left items of List<String> to FPTree
//              addSubTree(headItems, subtreeRoot, itemsDesc, count);
//          }
//      });
        return rootNode;
    }
  251.     /** 
  252.      *  
  253.      * @param headTable 
  254.      * @param subtreeRoot 
  255.      * @param itemsDesc 
  256.      * @param count 
  257.      */  
  258.     public static void addSubTree(List<TNode> headItems, TNode subtreeRoot, LinkedList<String> itemsDesc, int count)  
  259.     {  
  260.         if(itemsDesc.size() > 0)  
  261.         {  
  262.             final TNode thisNode = new TNode(itemsDesc.pop(), count);//construct a new node  
  263.             subtreeRoot.getChildren().add(thisNode);  
  264.             thisNode.setParent(subtreeRoot);  
  265.             //add thisNode to the relevant headTable node list  
  266.             for(TNode t : headItems)  
  267.             {  
  268.                 if(t.getItemName().equals(thisNode.getItemName()))  
  269.                 {  
  270.                     TNode lastNode = t;  
  271.                     while(lastNode.getNext() != null)  
  272.                     {  
  273.                         lastNode = lastNode.getNext();  
  274.                     }  
  275.                     lastNode.setNext(thisNode);  
  276.                 }  
  277.             }  
  278.             subtreeRoot = thisNode;//update thisNode as the current subtreeRoot  
  279.             //add the left items in itemsDesc recursively  
  280.             addSubTree(headItems, subtreeRoot, itemsDesc, count);  
  281.         }  
  282.     }  
  283.     /** 
  284.      * construct the headTable of the format <count, itemName> descended. 
  285.      * @param transactions  
  286.      * @return headTable 
  287.      */  
  288.     public static JavaRDD<TNode> bulidHeadTable(JavaPairRDD<String, Integer> transactions)  
  289.     {  
  290.         JavaRDD<TNode> headtable = transactions.flatMapToPair(new PairFlatMapFunction<Tuple2<String,Integer>, String, Integer>()  
  291.         {  
  292.             private static final long serialVersionUID = -3654849959547730063L;  
  293.   
  294.             public Iterable<Tuple2<String, Integer>> call(Tuple2<String, Integer> arg0)  
  295.                     throws Exception  
  296.             {  
  297.                 List<Tuple2<String, Integer>> t2list = new ArrayList<Tuple2<String, Integer>>();   
  298.                 String[] items = arg0._1.split(SEPARATOR);  
  299.                 int count = arg0._2;  
  300.                 for(String item : items)  
  301.                 {  
  302.                     t2list.add(new Tuple2<String, Integer>(item, count));  
  303.                 }  
  304.                 return t2list;  
  305.             }  
  306.         })  
  307.         //combine the same items.  
  308.         .reduceByKey(new Function2<Integer, Integer, Integer>()  
  309.         {  
  310.             private static final long serialVersionUID = 629605042999702574L;  
  311.   
  312.             public Integer call(Integer arg0, Integer arg1) throws Exception  
  313.             {  
  314.                 return arg0 + arg1;  
  315.             }  
  316.         })  
  317.         //convert <ietm,integer> to <integr,item> format.  
  318.         .mapToPair(new PairFunction<Tuple2<String,Integer>, Integer, String>()  
  319.         {  
  320.             private static final long serialVersionUID = -7017909569876993192L;  
  321.   
  322.             public Tuple2<Integer, String> call(Tuple2<String, Integer> t)  
  323.                     throws Exception  
  324.             {  
  325.                 return new Tuple2<Integer, String>(t._2, t._1);  
  326.             }  
  327.         })  
  328.         //filter out items which satisfies the minimum support_degree.  
  329.         .filter(new Function<Tuple2<Integer, String>, Boolean>()  
  330.         {  
  331.             private static final long serialVersionUID = -3643383589739281939L;  
  332.   
  333.             public Boolean call(Tuple2<Integer, String> v1) throws Exception  
  334.             {  
  335.                 return v1._1 >= SUPPORT_DEGREE;  
  336.             }  
  337.         })  
  338.         //sort items in descent.  
  339.         .sortByKey(false)  
  340.         //convert transactions to TNode.  
  341.         .map(new Function<Tuple2<Integer,String>, TNode>()  
  342.         {  
  343.             private static final long serialVersionUID = 16629827688034851L;  
  344.   
  345.             public TNode call(Tuple2<Integer, String> v1) throws Exception  
  346.             {  
  347.                 return new TNode(v1._2, v1._1);  
  348.             }  
  349.         });  
  350.         return headtable;  
  351.     }  
  352. }  
  353. </span>  
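The `bulidHeadTable` chain above does three things. It sums each item's count over the distinct transactions, with each transaction weighted by how many times it occurs. It keeps the items whose total reaches `SUPPORT_DEGREE`. It then sorts them by count in descending order.

Below is a minimal single-machine sketch of the same computation in plain Java, without Spark. It assumes the input is already a map from a space-separated transaction line to its occurrence count, which is the shape `constructTransactions` produces. The class and method names (`HeadTableSketch`, `buildHeadTable`) are hypothetical.

`sortByKey(false)` does not define an order among items with equal counts. This sketch breaks ties by item name so the result is deterministic; that tie-break is an added assumption.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical single-machine sketch of the header-table step (no Spark). */
public class HeadTableSketch {

    /**
     * Returns items whose weighted support is at least minSupport,
     * ordered by support descending (ties broken by item name).
     */
    public static LinkedHashMap<String, Integer> buildHeadTable(
            Map<String, Integer> transactions, int minSupport) {
        // Same role as flatMapToPair + reduceByKey: sum each item's weighted count
        Map<String, Integer> support = new HashMap<>();
        for (Map.Entry<String, Integer> t : transactions.entrySet()) {
            for (String item : t.getKey().trim().split("\\s+")) {
                support.merge(item, t.getValue(), Integer::sum);
            }
        }
        // Same role as filter + sortByKey(false)
        List<Map.Entry<String, Integer>> frequent = new ArrayList<>();
        for (Map.Entry<String, Integer> e : support.entrySet()) {
            if (e.getValue() >= minSupport) {
                frequent.add(e);
            }
        }
        frequent.sort((a, b) -> {
            int byCount = Integer.compare(b.getValue(), a.getValue());
            return byCount != 0 ? byCount : a.getKey().compareTo(b.getKey());
        });
        // LinkedHashMap preserves the sorted order
        LinkedHashMap<String, Integer> head = new LinkedHashMap<>();
        for (Map.Entry<String, Integer> e : frequent) {
            head.put(e.getKey(), e.getValue());
        }
        return head;
    }

    public static void main(String[] args) {
        Map<String, Integer> tx = new HashMap<>();
        tx.put("a b c", 2); // this transaction occurs twice
        tx.put("a c", 1);
        tx.put("b d", 1);
        // supports: a=3, b=3, c=3, d=1; d falls below minSupport 2
        System.out.println(buildHeadTable(tx, 2)); // {a=3, b=3, c=3}
    }
}
```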
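The loop in `bulidFPTree` inserts each weighted transaction in two steps, with `addSubTree` handling the second:

1. It walks down the longest existing prefix, adding the transaction's multiplicity to every node on it.
2. It hangs the remaining items below the last matched node, appending each new node to its item's node-link chain in the header table.

Below is a compact sketch of that insertion. It uses a simplified, hypothetical `Node` class, because `TNode` itself is not shown in this section. It also differs from the listing in two ways, both design choices rather than the original code:

- Children are stored in a hash map instead of searched with `findChildren`.
- A tail pointer is kept per item, so appending to a node-link is O(1) instead of walking the whole chain.

The names `FPTreeInsertSketch`, `insert`, `support`, `headFirst` and `headLast` are illustrative.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of FP-tree insertion with header node-links. */
public class FPTreeInsertSketch {

    static class Node {
        final String item;                              // null for the root
        int count;
        final Node parent;
        final Map<String, Node> children = new HashMap<>();
        Node next;                                      // next node with the same item (node-link)

        Node(String item, int count, Node parent) {
            this.item = item;
            this.count = count;
            this.parent = parent;
        }
    }

    final Node root = new Node(null, 0, null);
    final Map<String, Node> headFirst = new HashMap<>(); // first node of each item's chain
    final Map<String, Node> headLast = new HashMap<>();  // last node of each item's chain

    /** Inserts items (already in header-table order) that occur count times. */
    void insert(List<String> itemsDesc, int count) {
        Node current = root;
        for (String item : itemsDesc) {
            Node child = current.children.get(item);
            if (child != null) {
                // shared prefix: add this transaction's multiplicity
                child.count += count;
            } else {
                // new branch: create the node and append it to the item's node-link chain
                child = new Node(item, count, current);
                current.children.put(item, child);
                Node tail = headLast.get(item);
                if (tail == null) {
                    headFirst.put(item, child);
                } else {
                    tail.next = child;
                }
                headLast.put(item, child);
            }
            current = child;
        }
    }

    /** Sums counts along an item's node-link chain, i.e. its support in the tree. */
    int support(String item) {
        int total = 0;
        for (Node n = headFirst.get(item); n != null; n = n.next) {
            total += n.count;
        }
        return total;
    }

    public static void main(String[] args) {
        FPTreeInsertSketch tree = new FPTreeInsertSketch();
        tree.insert(Arrays.asList("a", "b", "c"), 2);
        tree.insert(Arrays.asList("a", "c"), 1);
        tree.insert(Arrays.asList("b"), 1);
        System.out.println(tree.root.children.get("a").count); // 3
        System.out.println(tree.support("c"));                 // 3 (spread over two c nodes)
    }
}
```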

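In FP-Growth, an item's conditional pattern base is built by walking that item's node-link chain. Each prefix path must be weighted by the count of the node where the path ends. In `FPGrowth`'s loop over `head.getNext()` that is `nextNode.getCount()`, not the item's total support `head.getCount()`; using the total for every path inflates the counts.

Below is a self-contained sketch of this step. It assumes a minimal, hypothetical `Node` class with the same `parent` and node-link `next` fields as `TNode`. Rather than repeating each path string `count` times, it returns each distinct prefix path (root side first) mapped to its summed count, ready to be used as new weighted transactions. The names `ConditionalBaseSketch` and `conditionalBase` are illustrative.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: the conditional pattern base of one item. */
public class ConditionalBaseSketch {

    static class Node {
        final String item;   // null for the root
        final int count;
        final Node parent;
        Node next;           // node-link to the next node with the same item

        Node(String item, int count, Node parent) {
            this.item = item;
            this.count = count;
            this.parent = parent;
        }
    }

    /**
     * Walks the node-link chain starting at firstNode and returns every
     * non-empty prefix path (root side first, space-separated) with its count.
     */
    static Map<String, Integer> conditionalBase(Node firstNode) {
        Map<String, Integer> base = new LinkedHashMap<>();
        for (Node n = firstNode; n != null; n = n.next) {
            List<String> path = new ArrayList<>();
            for (Node p = n.parent; p != null && p.item != null; p = p.parent) {
                path.add(p.item);
            }
            if (path.isEmpty()) {
                continue; // node hangs directly off the root: no prefix path
            }
            Collections.reverse(path);
            // weight by this node's count, not by the item's total support
            base.merge(String.join(" ", path), n.count, Integer::sum);
        }
        return base;
    }

    public static void main(String[] args) {
        // tree: root -> a(3) -> b(2) -> c(2); a(3) -> c(1); root -> c(1)
        Node root = new Node(null, 0, null);
        Node a = new Node("a", 3, root);
        Node b = new Node("b", 2, a);
        Node c1 = new Node("c", 2, b);
        Node c2 = new Node("c", 1, a);
        Node c3 = new Node("c", 1, root);
        c1.next = c2;
        c2.next = c3;
        System.out.println(conditionalBase(c1)); // {a b=2, a=1}
    }
}
```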
4 Run Results





