package cn.spark.study.core;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.SortedMap;
import java.util.TreeMap;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;

import com.google.common.base.Splitter;

import scala.Tuple2;
/**
 * K-nearest-neighbour (KNN) classification on Spark.
 *
 * <p>Every record of the query set R is paired with every record of the training set S
 * (cartesian product), the Euclidean distance of each pair is computed, and each R record is
 * assigned the classification that occurs most often among its k nearest S records.
 *
 * <p>Expected line formats (fields separated by {@code ;}, coordinates by {@code ,}):
 * <ul>
 *   <li>R: {@code recordId;x1,x2,...,xd}</li>
 *   <li>S: {@code recordId;classificationId;x1,x2,...,xd}</li>
 * </ul>
 *
 * <p>Usage: {@code KNN <rFile> <sFile> <d> <k> [debug]}, where {@code d} is the number of
 * dimensions and {@code k} the number of neighbours that vote. Passing {@code debug} as the
 * optional fifth argument prints the intermediate RDDs. It is off by default because it
 * collects the whole R x S product onto the driver.
 */
public class KNN {

    /** Number of mandatory command-line arguments: rFile, sFile, d, k. */
    private static final int REQUIRED_ARGS = 4;

    /** Orders (distance, classificationId) pairs by ascending distance. */
    private static final Comparator<Tuple2<Double, String>> BY_DISTANCE =
            new Comparator<Tuple2<Double, String>>() {
                @Override
                public int compare(Tuple2<Double, String> a, Tuple2<Double, String> b) {
                    return Double.compare(a._1, b._1);
                }
            };

    public static void main(String[] args) {
        // Validate before starting Spark: args[0]..args[3] are all read below.
        if (args.length < REQUIRED_ARGS) {
            System.err.println("Usage: KNN <rFile> <sFile> <d> <k> [debug]");
            System.exit(1);
        }
        int dimensions = Integer.parseInt(args[2]);
        int k = Integer.parseInt(args[3]);
        boolean debug = args.length > REQUIRED_ARGS
                && "debug".equalsIgnoreCase(args[REQUIRED_ARGS]);

        SparkConf conf = new SparkConf().setAppName("SparkKNN");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            run(jsc, args[0], args[1], dimensions, k, debug);
        } finally {
            // Release the context even if the job fails.
            jsc.stop();
        }
    }

    /**
     * Builds and executes the KNN pipeline, printing one {@code debug3} line per R record.
     *
     * @param jsc active Spark context
     * @param rPath path of the query set R
     * @param sPath path of the training set S
     * @param dimensions number of coordinates every point must have
     * @param k number of nearest neighbours that vote
     * @param debug whether to collect and print the intermediate RDDs
     */
    private static void run(JavaSparkContext jsc, String rPath, String sPath,
            int dimensions, int k, boolean debug) {
        // Skip blank lines (e.g. a trailing newline); they cannot be split into fields.
        Function<String, Boolean> nonBlank = new Function<String, Boolean>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Boolean call(String line) {
                return !line.trim().isEmpty();
            }
        };
        JavaRDD<String> rRecords = jsc.textFile(rPath, 1).filter(nonBlank);
        JavaRDD<String> sRecords = jsc.textFile(sPath, 1).filter(nonBlank);
        final Broadcast<Integer> broadcastD = jsc.broadcast(dimensions);
        final Broadcast<Integer> broadcastK = jsc.broadcast(k);

        // Pair every r in R with every s in S.
        JavaPairRDD<String, String> cart = rRecords.cartesian(sRecords);

        // For each (r, s): key = r's record ID, value = (distance(r, s), s's classification ID).
        JavaPairRDD<String, Tuple2<Double, String>> knnMapped = cart.mapToPair(
                new PairFunction<Tuple2<String, String>, String, Tuple2<Double, String>>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public Tuple2<String, Tuple2<Double, String>> call(Tuple2<String, String> pair)
                            throws Exception {
                        String[] rTokens = pair._1.split(";");
                        String[] sTokens = pair._2.split(";");
                        String rRecordId = rTokens[0];
                        String sClassificationId = sTokens[1];
                        double distance =
                                calculateDistance(rTokens[1], sTokens[2], broadcastD.value());
                        return new Tuple2<String, Tuple2<Double, String>>(
                                rRecordId,
                                new Tuple2<Double, String>(distance, sClassificationId));
                    }
                });
        if (debug) {
            for (Tuple2<String, Tuple2<Double, String>> entry : knnMapped.collect()) {
                System.out.println("debug1 key=" + entry._1 + " vlaue=" + entry._2);
            }
        }

        JavaPairRDD<String, Iterable<Tuple2<Double, String>>> knnGrouped =
                knnMapped.groupByKey();
        if (debug) {
            for (Tuple2<String, Iterable<Tuple2<Double, String>>> entry : knnGrouped.collect()) {
                System.out.println("debug2 key=" + entry._1 + " vlaue=" + entry._2);
            }
        }

        // Majority vote among the k nearest neighbours of each R record.
        JavaPairRDD<String, String> knnOutput = knnGrouped.mapValues(
                new Function<Iterable<Tuple2<Double, String>>, String>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public String call(Iterable<Tuple2<Double, String>> neighbors)
                            throws Exception {
                        List<Tuple2<Double, String>> nearest =
                                selectNearestK(neighbors, broadcastK.value());
                        return classifyByMajority(countVotes(nearest));
                    }
                });
        for (Tuple2<String, String> entry : knnOutput.collect()) {
            System.out.println("debug3 key=" + entry._1 + " vlaue=" + entry._2);
        }
    }

    /**
     * Splits a delimited list of numbers into doubles.
     *
     * @param str a comma- or semicolon-separated list of double values
     * @param delimiter the separator, e.g. "," or ";"
     * @return all attributes of one data-set record, in input order (never null)
     * @throws NumberFormatException if a token is not a valid double (including an empty token)
     */
    static List<Double> splitOnToListOfDouble(String str, String delimiter) {
        List<Double> list = new ArrayList<Double>();
        for (String token : Splitter.on(delimiter).trimResults().split(str)) {
            list.add(Double.parseDouble(token));
        }
        return list;
    }

    /**
     * Euclidean distance between two comma-separated points.
     *
     * @param rAsString "r.1,r.2,...,r.n"
     * @param sAsString "s.1,s.2,...,s.n"
     * @param d the number of dimensions of R and S
     * @return the distance, or {@code Double.NaN} if either point does not have exactly d
     *     coordinates
     */
    static double calculateDistance(String rAsString, String sAsString, int d) {
        List<Double> r = splitOnToListOfDouble(rAsString, ",");
        List<Double> s = splitOnToListOfDouble(sAsString, ",");
        if (r.size() != d || s.size() != d) {
            return Double.NaN;
        }
        double sum = 0;
        for (int i = 0; i < d; i++) {
            double difference = r.get(i) - s.get(i);
            sum += difference * difference;
        }
        return Math.sqrt(sum);
    }

    /**
     * Returns up to k (distance -> classificationId) entries with the smallest distances.
     *
     * <p>Because the map is keyed by distance, neighbours at exactly the same distance overwrite
     * one another (the last one seen wins), so fewer than k neighbours may be returned. The
     * pipeline uses {@link #selectNearestK} instead. This method is kept for compatibility.
     */
    static SortedMap<Double, String> findNearestK(Iterable<Tuple2<Double, String>> neighbors,
            int k) {
        SortedMap<Double, String> nearestK = new TreeMap<Double, String>();
        for (Tuple2<Double, String> neighbor : neighbors) {
            nearestK.put(neighbor._1, neighbor._2);
            if (nearestK.size() > k) {
                nearestK.remove(nearestK.lastKey());
            }
        }
        return nearestK;
    }

    /**
     * Selects the k neighbours with the smallest distance, nearest first.
     *
     * <p>Unlike {@link #findNearestK}, all neighbours at equal distance are kept (up to k).
     * Neighbours with a null or NaN distance (see {@link #calculateDistance}) are skipped, so
     * records with the wrong number of dimensions do not vote.
     *
     * @param neighbors (distance, classificationId) pairs for one query record
     * @param k maximum number of neighbours to keep; {@code k <= 0} yields an empty list
     * @return at most k pairs sorted by ascending distance
     */
    private static List<Tuple2<Double, String>> selectNearestK(
            Iterable<Tuple2<Double, String>> neighbors, int k) {
        List<Tuple2<Double, String>> nearest = new ArrayList<Tuple2<Double, String>>();
        if (k <= 0) {
            return nearest;
        }
        // Max-heap on distance: the head is the farthest candidate, i.e. the one to evict.
        // The initial capacity is capped so a huge k neither pre-allocates nor overflows.
        PriorityQueue<Tuple2<Double, String>> heap = new PriorityQueue<Tuple2<Double, String>>(
                Math.min(k, 1024) + 1, Collections.reverseOrder(BY_DISTANCE));
        for (Tuple2<Double, String> neighbor : neighbors) {
            Double distance = neighbor._1;
            if (distance == null || distance.isNaN()) {
                continue;
            }
            heap.add(neighbor);
            if (heap.size() > k) {
                heap.poll();
            }
        }
        nearest.addAll(heap);
        Collections.sort(nearest, BY_DISTANCE);
        return nearest;
    }

    /**
     * Counts how many times each classification ID occurs among the values of the map.
     *
     * @param nearestK distance -> classificationId entries
     * @return classificationId -> number of occurrences
     */
    static Map<String, Integer> buildClassificationCount(Map<Double, String> nearestK) {
        Map<String, Integer> majority = new HashMap<String, Integer>();
        for (String classificationId : nearestK.values()) {
            Integer count = majority.get(classificationId);
            majority.put(classificationId, count == null ? 1 : count + 1);
        }
        return majority;
    }

    /**
     * Counts how many of the given neighbours belong to each classification.
     *
     * @param neighbors (distance, classificationId) pairs; duplicates in distance all count
     * @return classificationId -> number of neighbours with that classification
     */
    private static Map<String, Integer> countVotes(Iterable<Tuple2<Double, String>> neighbors) {
        Map<String, Integer> votes = new HashMap<String, Integer>();
        for (Tuple2<Double, String> neighbor : neighbors) {
            Integer count = votes.get(neighbor._2);
            votes.put(neighbor._2, count == null ? 1 : count + 1);
        }
        return votes;
    }

    /**
     * Picks the classification with the most votes.
     *
     * <p>Ties are broken by choosing the lexicographically smallest classification ID, so the
     * result does not depend on the iteration order of the supplied map.
     *
     * @param majority classificationId -> vote count
     * @return the winning classification ID, or {@code null} if the map is empty
     */
    static String classifyByMajority(Map<String, Integer> majority) {
        String selectedClassification = null;
        int votes = Integer.MIN_VALUE;
        for (Map.Entry<String, Integer> entry : majority.entrySet()) {
            String candidate = entry.getKey();
            int count = entry.getValue();
            boolean wins = count > votes
                    || (count == votes && selectedClassification != null && candidate != null
                            && candidate.compareTo(selectedClassification) < 0);
            if (wins) {
                selectedClassification = candidate;
                votes = count;
            }
        }
        return selectedClassification;
    }
}
测试数据(R 文件每行格式为 记录ID;坐标,S 文件每行格式为 记录ID;类别ID;坐标,坐标之间用逗号分隔):
rKNN.txt
1000;2.0,3.0
1001;10.1,3.2
1003;2.7,2.7
1004;5.0,5.0
1005;13.1,2.2
1006;12.7,12.7
sKNN.txt
100;c1;1.0,1.0
101;c1;1.1,1.2
102;c1;1.2,1.0
103;c1;1.6,1.5
104;c1;1.3,1.7
105;c1;2.0,2.1
106;c1;2.0,2.2
107;c1;2.3,2.3
208;c2;9.0,9.0
209;c2;9.1,9.2
210;c2;9.2,9.0
211;c2;10.6,10.5
212;c2;10.3,10.7
213;c2;9.6,9.1
214;c2;9.4,10.4
215;c2;10.3,10.3
300;c3;10.0,1.0
301;c3;10.1,1.2
302;c3;10.2,1.0
303;c3;10.6,1.5
304;c3;10.3,1.7
305;c3;10.0,2.1
306;c3;10.0,2.2
307;c3;10.3,2.3
提交脚本(参数依次为 rFile sFile d k,其中 d 为维数,k 为参与投票的近邻个数):
k=4
d=2
rFile=hdfs://spark01:9000/KNN/rKNN.txt
sFile=hdfs://spark01:9000/KNN/sKNN.txt
/usr/local/spark1.5/bin/spark-submit \
--class cn.spark.study.core.KNN \
--num-executors 3 \
--driver-memory 100m \
--executor-memory 100m \
--executor-cores 3 \
/usr/local/spark-text/java/KNN/knn.jar $rFile $sFile $d $k
运行结果(key 为 R 中的记录ID,value 为预测出的类别):
debug3 key=1005 vlaue=c3
debug3 key=1001 vlaue=c3
debug3 key=1000 vlaue=c1
debug3 key=1004 vlaue=c1
debug3 key=1006 vlaue=c2
debug3 key=1003 vlaue=c1