KNN 假如有一群已知分类的点集: //S.txt 100;c1;1.0,1.0 101;c1;1.1,1.2 102;c1;1.2,1.0 103;c1;1.6,1.5 104;c1;1.3,1.7 105;c1;2.0,2.1 106;c1;2.0,2.2 107;c1;2.3,2.3 208;c2;9.0,9.0 209;c2;9.1,9.2 210;c2;9.2,9.0 211;c2;10.6,10.5 212;c2;10.3,10.7 213;c2;9.6,9.1 214;c2;9.4,10.4 215;c2;10.3,10.3 300;c3;10.0,1.0 301;c3;10.1,1.2 302;c3;10.2,1.0 303;c3;10.6,1.5 304;c3;10.3,1.7 305;c3;1.0,2.1 306;c3;10.0,2.2 307;c3;10.3,2.3 和一群未知分类的点集: //R.txt 1000;3.0,3.0 1001;10.1,3.2 1003;2.7,2.7 1004;5.0,5.0 1005;13.1,2.2 1006;12.7,12.7 如何为R中的每一个点找到它合适的分类呢? KNN(K近邻)算法: (1)确定K(K的选择取决于具体的数据和项目需求); (2)计算新输入(如【1000;3.0,3.0】)与所有训练数据之间的距离(与K一样,距离函数的选择也取决于数据的类型); (3)对距离排序,并根据前K个最小距离确定K个近邻; (4)收集这些近邻所属的类别; (5)根据多数投票确定类别。 通俗来说: 有一群土豪:土豪1,土豪2,土豪3,土豪4... 有一群屌丝:屌丝1,屌丝2,屌丝3,屌丝4... 现在来了一个人,如何判断这个人是屌丝还是土豪呢? 先计算这个人和所有土豪以及所有屌丝的距离(存款、房产等),然后将这些距离按从小到大的顺序排列: d1 < d2 < d3 < ... 然后只看前K个距离,统计它们分别是这个人跟谁比较得来的,例如: d1=distance<person,土豪32>,则给这个人投一土豪票 d2=distance<person,屌丝100>,则给这个人投一屌丝票 d3=distance<person,土豪1>,则给这个人投一土豪票 ...(一直统计到dK) 最后看哪一类的票数最多,如果土豪票票数最多,则将这个人分类为土豪。 MapReduce实现思路: map执行之前(setup阶段),将已分类文件S.txt缓存到内存中; map阶段:每次从R中读入一行,再遍历S的每一行,求这两行对应的点之间的距离,并生成(distance,classification),如(3,土豪),即R的这一行与这个土豪的距离为3; map输出:(rID,(distance,classification)) reduce输入:(rID,[(distance,classification),(distance,classification),(distance,classification)....]) reduce:由于reduce输入的[(distance,classification),(distance,classification),(distance,classification)....]是无序的,因此需要对[(distance,classification),(distance,classification),(distance,classification)....]按distance进行排序,然后取前K个进行多数投票。 由于这个集合可能很大以至于内存中无法存放,因此需要另想办法,使这个集合到达reduce输入时就已经是有序的。 关键技术:组合键、二次排序。 将map输出的键由自然键rID改为组合键(rID,distance),通过自定义分区器和分组比较器使map的输出按rID分区和分组,同时能按distance排序。
package cjknn; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.net.URI; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configured; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.WritableComparable; import org.apache.hadoop.io.WritableComparator; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.Mapper; import org.apache.hadoop.mapreduce.Partitioner; import org.apache.hadoop.mapreduce.Reducer; import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat; import org.apache.hadoop.util.Tool; import org.apache.hadoop.util.ToolRunner; import edu.umd.cloud9.io.pair.PairOfFloatString; import edu.umd.cloud9.io.pair.PairOfStringFloat; public class CJKNN extends Configured implements Tool { public static final int DIMENS = 2;//点的维数 public static final int K = 6; public static ArrayList<Float> getVectorFromStr(String str) { String[] vectorStr = str.split(","); ArrayList<Float> vector = new ArrayList<Float>(); for(int i=0;i<vectorStr.length && i < DIMENS; i++) { vector.add(Float.valueOf(vectorStr[i])); } return vector; } public static class CJKNNMapper extends Mapper<LongWritable, Text, PairOfStringFloat, PairOfFloatString> { PairOfStringFloat outputKey = new PairOfStringFloat(); PairOfFloatString outputValue = new PairOfFloatString(); List<String> S = new ArrayList<String>(); @Override protected void setup( Mapper<LongWritable, Text, PairOfStringFloat, PairOfFloatString>.Context context) throws IOException, InterruptedException { //get 
S from cache FileReader fr = new FileReader("S"); BufferedReader br = new BufferedReader(fr); String line = null; while((line = br.readLine()) != null) { S.add(line); } fr.close(); br.close(); } @Override protected void map( LongWritable key, Text value, Mapper<LongWritable, Text, PairOfStringFloat, PairOfFloatString>.Context context) throws IOException, InterruptedException { String rID = value.toString().split(";")[0]; String rVectorStr = value.toString().split(";")[1]; ArrayList<Float> rVector = getVectorFromStr(rVectorStr); for(String s : S) { ArrayList<Float> sVector = getVectorFromStr(s.split(";")[2]); float distance = calculateDistance(rVector,sVector); outputKey.set(rID, distance); outputValue.set(distance, s.split(";")[1]); context.write(outputKey, outputValue); } } private float calculateDistance(ArrayList<Float> rVector, ArrayList<Float> sVector) { double sum = 0.0; for(int i=0;i<rVector.size() && i < DIMENS;i++) { sum += Math.pow((rVector.get(i) - sVector.get(i)), 2); } return (float) Math.sqrt(sum); } } public static class CJGroupingComparator extends WritableComparator { public CJGroupingComparator() { super(PairOfStringFloat.class, true); } @SuppressWarnings("rawtypes") @Override public int compare(WritableComparable wc1, WritableComparable wc2) { PairOfStringFloat pair = (PairOfStringFloat) wc1; PairOfStringFloat pair2 = (PairOfStringFloat) wc2; int result = pair.getLeftElement().compareTo(pair2.getLeftElement()); return -result; } } /*** * 定制分区器 * 分区器会根据映射器的输出键来决定哪个映射器的输出发送到哪个规约器。为此我们需要定义两个插件类 * 首先需要一个定制分区器控制哪个规约器处理哪些键,另外还要定义一个定制比较器对规约器值排序。 * 这个定制分区器可以确保具有相同键(自然键,而不是包含温度值的组合键)的所有数据都发送给同一个规约器。 * 定制比较器会完成排序,保证一旦数据到达规约器,就会按自然键对数据分组。 * @author chenjie * */ public class CJPartitioner extends Partitioner<PairOfStringFloat, Text> { @Override public int getPartition(PairOfStringFloat pair, Text text, int numberOfPartitions) { // make sure that partitions are non-negative return Math.abs(pair.getLeftElement().hashCode() % numberOfPartitions); } } public 
static class CJKNNReducer extends Reducer<PairOfStringFloat, PairOfFloatString, Text, Text> { @Override protected void reduce( PairOfStringFloat key, Iterable<PairOfFloatString> values, Context context) throws IOException, InterruptedException { System.out.println("key= " + key); System.out.println("values:"); Map<String,Integer> map = new HashMap<String,Integer>(); int count = 0; Iterator<PairOfFloatString> iterator = values.iterator(); while(iterator.hasNext()) { PairOfFloatString value = iterator.next(); System.out.println(value); String sClassificationID = value.getRightElement(); Integer times = map.get(sClassificationID); if (times== null ) { map.put(sClassificationID, 1); } else { map.put(sClassificationID, times+1); } count ++; if(count >= K) break; } int max = 0; String maxSClassificationID = ""; System.out.println("map:"); for(Map.Entry<String, Integer> entry : map.entrySet()) { System.out.println(entry); if(entry.getValue() > max) { max = entry.getValue(); maxSClassificationID = entry.getKey(); } } context.write(new Text(key.getLeftElement()), new Text(maxSClassificationID)); } } public static void main(String[] args) throws Exception { args = new String[2]; args[0] = "/media/chenjie/0009418200012FF3/ubuntu/R.txt"; args[1] = "/media/chenjie/0009418200012FF3/ubuntu/CJKNN";; int jobStatus = submitJob(args); System.exit(jobStatus); } public static int submitJob(String[] args) throws Exception { int jobStatus = ToolRunner.run(new CJKNN(), args); return jobStatus; } @SuppressWarnings("deprecation") @Override public int run(String[] args) throws Exception { Configuration conf = getConf(); Job job = new Job(conf); job.setJobName("KNN"); job.setInputFormatClass(TextInputFormat.class); job.setOutputFormatClass(TextOutputFormat.class); job.setOutputKeyClass(PairOfStringFloat.class); job.setOutputValueClass(PairOfFloatString.class); job.setMapperClass(CJKNNMapper.class); job.setReducerClass(CJKNNReducer.class); FileInputFormat.setInputPaths(job, new Path(args[0])); 
FileOutputFormat.setOutputPath(job, new Path(args[1])); job.addCacheArchive(new URI("/media/chenjie/0009418200012FF3/ubuntu/S.txt" + "#S")); job.setPartitionerClass(CJPartitioner.class); job.setGroupingComparatorClass(CJGroupingComparator.class); FileSystem fs = FileSystem.get(conf); Path outPath = new Path(args[1]); if(fs.exists(outPath)) { fs.delete(outPath, true); } boolean status = job.waitForCompletion(true); return status ? 0 : 1; } }
Spark解决方案:
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext

/**
 * k-nearest-neighbour (kNN) classification with Spark.
 *
 * R records look like `1000;3.0,3.0` (id;coordinates) and S records like `100;c1;1.0,1.0`
 * (id;classification;coordinates). Every R point is paired with every S point (Cartesian
 * product) and the Euclidean distance is computed. The k nearest S points are kept per R
 * point, and the R point is assigned the majority classification among them.
 *
 * Usage: KNN [R path] [S path] [output path]. Missing arguments fall back to the defaults below.
 */
object KNN {

  private val DefaultInputR = "file:///media/chenjie/0009418200012FF3/ubuntu/R.txt"
  private val DefaultInputS = "file:///media/chenjie/0009418200012FF3/ubuntu/S.txt"
  private val DefaultOutput = "file:///media/chenjie/0009418200012FF3/ubuntu/KNN"

  /**
   * Euclidean distance between two points given as comma-separated coordinates.
   *
   * @param rAsString as r1,r2, ..., rd
   * @param sAsString as s1,s2, ..., sd
   * @param d         number of dimensions
   * @return the distance, or Double.NaN if either point does not have exactly d coordinates
   */
  def calculateDistance(rAsString: String, sAsString: String, d: Int): Double = {
    val r = rAsString.split(",").map(_.toDouble)
    val s = sAsString.split(",").map(_.toDouble)
    if (r.length != d || s.length != d) Double.NaN
    else math.sqrt(r.indices.map(i => math.pow(r(i) - s(i), 2)).sum)
  }

  def main(args: Array[String]): Unit = {
    val inputDatasetR = if (args.length > 0) args(0) else DefaultInputR
    val inputDatasetS = if (args.length > 1) args(1) else DefaultInputS
    val output = if (args.length > 2) args(2) else DefaultOutput

    val k = 4 // number of neighbours that vote
    val d = 2 // number of coordinates per point

    // Run locally unless a master was supplied (e.g. by spark-submit --master).
    val sparkConf = new SparkConf().setAppName("kNN").setIfMissing("spark.master", "local")
    val sc = new SparkContext(sparkConf)
    try {
      // (rRecordID, coordinates), e.g. ("1000", "3.0,3.0")
      val rRecords = sc.textFile(inputDatasetR).filter(_.trim.nonEmpty).map { line =>
        val tokens = line.split(";")
        (tokens(0), tokens(1))
      }
      // (sClassificationID, coordinates), e.g. ("c1", "1.0,1.0")
      val sRecords = sc.textFile(inputDatasetS).filter(_.trim.nonEmpty).map { line =>
        val tokens = line.split(";")
        (tokens(1), tokens(2))
      }

      // (rRecordID, (distance, sClassificationID)) for every (r, s) pair,
      // e.g. ("1000", (2.8284, "c1")). Pairs with malformed coordinates (NaN) are dropped.
      val knnMapped = (rRecords cartesian sRecords)
        .map { case ((rRecordID, r), (sClassificationID, s)) =>
          (rRecordID, (calculateDistance(r, s, d), sClassificationID))
        }
        .filter { case (_, (distance, _)) => !distance.isNaN }

      // Keep only the k nearest neighbours per R record while aggregating. Unlike
      // groupByKey, this never holds all |S| distances of one key in memory.
      val nearestK = knnMapped.aggregateByKey(List.empty[(Double, String)])(
        (nearest, neighbour) => (neighbour :: nearest).sortBy(_._1).take(k),
        (left, right) => (left ::: right).sortBy(_._1).take(k))

      // Majority vote. On a tie, prefer the class whose closest member is nearest,
      // so the result does not depend on hash-map iteration order.
      val knnOutput = nearestK.mapValues { neighbours =>
        neighbours
          .groupBy(_._2)
          .maxBy { case (_, members) => (members.size, -members.map(_._1).min) }
          ._1
      } // e.g. ("1005", "c3")

      // Two actions follow; cache so the Cartesian product is not recomputed for the second.
      knnOutput.cache()
      knnOutput.foreach(println)
      knnOutput.saveAsTextFile(output)
    } finally {
      sc.stop()
    }
  }
}
输出结果:
1000 c1
1001 c3
1003 c1
1004 c1
1005 c3
1006 c2