基于Spark on Yarn的apriori算法java实现

一  前言

为了处理大数据集并从中找出强关联规则，本文基于 Spark、使用 Java 语言实现了 Apriori 算法。算法已经通过测试，文末附有一个测试实验及其运行结果。

二  apriori算法描述

Apriori 是一种经典的数据挖掘算法，可以挖掘出数据库中哪些物品经常一起出现；同时满足最小支持度和最小置信度的规则称为强关联规则。算法的目标是找出所有强关联规则，从而为实际决策提供依据，或用于预测未来的结果。Apriori 算法采用逐层搜索的迭代思想：利用频繁 k 项集找出频繁 (k+1) 项集，依次类推，直到找出所有频繁项集。其依据是 Apriori 性质：频繁项集的所有非空子集也一定是频繁的，因此含有非频繁子集的候选项集可以直接剪枝。最后，再从已经找出的频繁项集中进一步找出所有强关联规则。本文实现的算法主要完成"找出所有频繁项集"这一步，这也是 Apriori 算法中最重要的一步。

三  算法实现

package org.min.apriori;


import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;


import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;


import scala.Tuple2;
import scala.collection.mutable.ArrayBuffer;


/**
 * 
 * @author ShiMin
 * @date   2015/10/13
 * @description APriori algorithm runs on spark in java. 
 *
 */
public class FrequentItemset
{
public static int SUPPORT_DEGREE = 4;//the support of APriori algorithm
public static int TRANSACTION_NUM = 25;//the number of transaction
public static String SEPARATOR = " ";//line separator
public static int NOFITEMS = 4;//the number of items in itemSet


@SuppressWarnings("serial")
public static void main(String[] args)
{
Logger.getLogger("org.apache.spark").setLevel(Level.OFF);
args = new String[]{"hdfs://master:9000/data/input/wordcounts.txt", "hdfs://master:9000/data/output"};

if(args.length != 2)
{
System.err.println("USage:<Datapath> <Output>");
System.exit(1);
}

SparkConf sparkConf = new SparkConf().setAppName("apriori algorithm").setMaster("local[4]");
JavaSparkContext ctx = new JavaSparkContext(sparkConf);


JavaRDD<String> lines = ctx.textFile(args[0], 1); //textFile(path: String, minPartitions: Int)

//remove the ID of transaction.
JavaPairRDD<String,Integer> ones = lines.map(new Function<String, String>()
{
public String call(String v1) throws Exception
{
return v1.substring(v1.indexOf(" ") + 1).trim();
}
})
//convert lines to <key,value>(or <line,1>) pairs.
.mapToPair(new PairFunction<String, String, Integer>()
{
public Tuple2<String, Integer> call(String t) throws Exception
{
return new Tuple2<String, Integer>(t, 1);
}
})
//combine the same translations.
.reduceByKey(new Function2<Integer, Integer, Integer>()
{
public Integer call(Integer v1, Integer v2) throws Exception
{
return v1 + v2 ;
}
});

//convert <line,count> pairs to <List<String>,count> pairs form.
JavaPairRDD<List<String>, Integer> transactions = ones.mapToPair(new PairFunction<Tuple2<String,Integer>, List<String>, Integer>()
{
public Tuple2<List<String>, Integer> call(Tuple2<String, Integer> t)
throws Exception
{
String[] items = t._1.split(SEPARATOR);
List<String> itemlist = Arrays.asList(items);;
return new Tuple2<List<String>, Integer>(itemlist, t._2);
}
})
//cache the transaction in memory.
.cache();

//count the 1 frequent itemSet which satisfies the minimum support_degree.
JavaPairRDD<String, Integer> onefi = transactions.flatMapToPair(new PairFlatMapFunction<Tuple2<List<String>,Integer>, String, Integer>()
{
public Iterable<Tuple2<String, Integer>> call(Tuple2<List<String>, Integer> t)
throws Exception
{
List<Tuple2<String, Integer>> t2list = new ArrayList<Tuple2<String, Integer>>();
for(String item : t._1)
{
t2list.add(new Tuple2<String, Integer>(item, t._2));
}
return t2list;
}
})
//combine the same item.
.reduceByKey(new Function2<Integer, Integer, Integer>()
{
public Integer call(Integer v1, Integer v2) throws Exception
{
return v1 + v2 ;
}
})
//filter out the satisfactory item.
.filter(new Function<Tuple2<String,Integer>, Boolean>()
{
public Boolean call(Tuple2<String, Integer> v1) throws Exception
{
return v1._2 >= SUPPORT_DEGREE;
}
})
//cache the 1 frequent itemSet in memory.
.cache();

//compute the support rate of each item.
onefi.map(new Function<Tuple2<String,Integer>, String>()
{
public String call(Tuple2<String, Integer> v1) throws Exception
{
return v1._1 + ":" + (double)v1._2 / TRANSACTION_NUM;
}
})
//save the 1 frequent itemSet to the result_1.txt.
.saveAsTextFile(args[1] + "/result_1");

//count the k frequent itemSet which satisfies the minimun support_degree.
JavaPairRDD<String, Integer> kfi = onefi;
for(int k = 2; k <= NOFITEMS; ++k)
{
List<String> candiatefi = getCandiateKFI(kfi, k);
System.out.println(k + " = " + candiatefi);
JavaRDD<String> firdd = ctx.parallelize(candiatefi);
final Broadcast<JavaRDD<String>> bccFI = ctx.broadcast(firdd);

kfi = transactions.flatMapToPair(new PairFlatMapFunction<Tuple2<List<String>,Integer>, String, Integer>()
{
private static final long serialVersionUID = 3107941823066446782L;


public Iterable<Tuple2<String, Integer>> call(
final Tuple2<List<String>, Integer> line) throws Exception
{
List<Tuple2<String, Integer>> t2list = bccFI.value().flatMapToPair(new PairFlatMapFunction<String, String, Integer>()
{


public Iterable<Tuple2<String, Integer>> call(String t)
throws Exception
{
List<Tuple2<String, Integer>> lineitemlist = new ArrayList<Tuple2<String, Integer>>();
String[] items = t.split(",");
if(line._1.containsAll(Arrays.asList(items)))
{
lineitemlist.add(new Tuple2<String, Integer>(t, line._2));
System.out.println("line=" + line + " items=" + Arrays.asList(items) + " flag=" + (line._1.containsAll(Arrays.asList(items))));
}
return lineitemlist;
}
}).collect();

System.out.println("t2list" + "=" + t2list + " line=" + line);
return t2list;
}
})
//combine the same item.
.reduceByKey(new Function2<Integer, Integer, Integer>()
{
public Integer call(Integer v1, Integer v2) throws Exception
{
return v1 + v2;
}
})
//filter out the satisfactory item.
.filter(new Function<Tuple2<String,Integer>, Boolean>()
{
public Boolean call(Tuple2<String, Integer> v1) throws Exception
{
System.out.println(v1._1 + ":" + v1._2);

return v1._2 >= SUPPORT_DEGREE;
}
})
//cache the k frequent itemSet in memory.
.cache();

//compute the support rate of each item.
kfi.map(new Function<Tuple2<String,Integer>, String>()
{
public String call(Tuple2<String, Integer> v1) throws Exception
{
return v1._1 + ":" + (double)v1._2 / TRANSACTION_NUM;
}
})
//save the k frequent itemSet to the result_k.txt.
.saveAsTextFile(args[1] + "/result_" + k);
}

// onefi.foreach(new VoidFunction<Tuple2<String, Integer>>()
// {
// public void call(Tuple2<String, Integer> t) throws Exception
// {
// System.out.println(t);
// }
// });

}

public static List<String> getCandiateKFI(JavaPairRDD<String, Integer> kfi, int k)
{
List<String> candiateItemSet = new ArrayList<String>();

//extract the items,save them in list.
List<String> itemlist = kfi.map(new Function<Tuple2<String,Integer>, String>()
{
public String call(Tuple2<String, Integer> v1) throws Exception
{
return v1._1;
}
}).collect();

for(int i = 0; i < itemlist.size() - 1; i++)
{
for(int j = i + 1; j < itemlist.size(); j++)
{
String tmpItem = "";

if(2 == k)
{
tmpItem = itemlist.get(i) + "," + itemlist.get(j);
tmpItem = sortItems(tmpItem);
}
else
{
String s1 = itemlist.get(i);
String s2 = itemlist.get(j);
if(s1.substring(0, s1.lastIndexOf(',')).equals(s2.substring(0, s2.lastIndexOf(','))))
{
tmpItem = s1 + s2.substring(s2.lastIndexOf(','));
tmpItem = sortItems(tmpItem);
}
}

//filter the item which has infrequent subItem.
boolean hasInfrequentSubItem = false;
if(!"".equals(tmpItem))
{
String[] items = tmpItem.split(",");
for(int m = 0; m < items.length; m++)
{
String subItem = "";
for(int n = 0; n < items.length; n++)
{
if(m != n)
{
subItem += items[n] + ",";
}
}
subItem = subItem.substring(0, subItem.lastIndexOf(','));

if(!itemlist.contains(subItem))
{
hasInfrequentSubItem = true;
break;
}
}
}
else
{
hasInfrequentSubItem = true;
}

if(!hasInfrequentSubItem)
{
candiateItemSet.add(tmpItem);
}
}
}
return candiateItemSet;
}

public static String sortItems(String itemStr)
{
String result = "";
String[] items = itemStr.split(",");
Arrays.sort(items);
for(String item : items)
{
result += item + ",";
}
return result.substring(0, result.lastIndexOf(','));
}



}


四  实验结果

程序中设定最多挖掘到 4 项集，生成的结果分别保存在 4 个输出目录（result_1 至 result_4）中，结果如下：


每个文件内容如下:

发布了34 篇原创文章 · 获赞 10 · 访问量 4万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 编程工作室 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览