MapReduce实现KNN

不正之处,欢迎指正。

        KNN算法称为K近邻分类算法,是最简单的分类器,KNN算法从训练集中找到和测试数据距离最近的K个记录,然后根据这K个记录的标记来决定测试实例的最终标记。MapReduce作为一种大数据环境下的计算模型,在分布式计算中具有其独特的优势,本文主要在hadoop框架下面实现KNN算法。

        实验环境:centos6.5+hadoop2.2.0

实验步骤:

         MapReduce的关键之处在于实现用户自定义的map和Reduce函数,在本实例中,我们在mapper类中的clean函数中首先读取所有的训练数据,用一个List来进行存储。在map阶段,逐行读取每一个测试实例,计算测试实例和训练数据之间的距离,找到最近的k个距离所对应的标记。在Reduce阶段,通过统计map阶段的标记信息,找到出现次数最多的标记就是最终的测试用例标记。实验代码如下:

package org.apache.hadoop.knn;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Set;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class Knn {

	public static class KnnMap extends Mapper<LongWritable, Text, Text, Text> {
		public ArrayList<Instance> train = new ArrayList<Instance>();       //存储训练集
		public int k = 5;

		@Override
		//读取训练集
		protected void setup(
				Mapper<LongWritable, Text, Text, Text>.Context context)
				throws IOException, InterruptedException {
			// TODO Auto-generated method stub
			// super.setup(context);
			FileSystem fs = null;
			try {
				fs = FileSystem.get(new URI("hdfs://192.168.1.119:9000"), new Configuration());
			} catch (Exception e) {
				
			}
			FSDataInputStream fi = fs.open(new Path(
					"hdfs://192.168.1.119:9000/data/traindata.txt"));
			BufferedReader bf = new BufferedReader(new InputStreamReader(fi));
			String line = bf.readLine();
			while (line != null) {
				Instance sample = new Instance(line);
				train.add(sample);
				line = bf.readLine();
			}
		}

		@Override
		protected void map(LongWritable key, Text value, Context context)
				throws IOException, InterruptedException {
			// TODO Auto-generated method stub
			// super.map(key, value, context);
			ArrayList<Double> distance = new ArrayList<Double>(k);
			ArrayList<String> trainlabel = new ArrayList<String>(k);
			for (int i = 0; i < k; i++) {
				distance.add(Double.MAX_VALUE);
				trainlabel.add(String.valueOf("-1.0"));
			}
			Instance test = new Instance(value.toString());
			for (int i = 0; i < train.size(); i++) {
				double dis = Distance(train.get(i).getFeatures(),
						test.getFeatures());
				for (int j = 0; j < k; j++) {
					if (dis < (Double) distance.get(j)) {
						distance.set(j, dis);
						trainlabel.set(j, train.get(i).getLabel() + "");
						break;
					}
				}
			}
			for (int i = 0; i < k; i++) {
				context.write(new Text(value.toString()),
						new Text(trainlabel.get(i) + ""));
			}
		}

		private double Distance(double[] a, double[] b) {
			// TODO Auto-generated method stub
			double sum = 0.0;
			for (int i = 0; i < a.length; i++) {
				sum += Math.pow(a[i] - b[i], 2);
			}
			return Math.sqrt(sum);
		}

	}

	public static class KnnReducer extends
			Reducer<Text, Text, Text, NullWritable> {
		@Override
		protected void reduce(Text k, Iterable<Text> values, Context context)
				throws IOException, InterruptedException {
			// TODO Auto-generated method stub
			// super.reduce(arg0, arg1, arg2);
			ArrayList<String> l = new ArrayList<String>();
			for (Text t : values) {
				l.add(t.toString());
			}
			String predict = Predict(l);
			context.write(new Text(k.toString() + "\t" + predict),
					NullWritable.get());

		}

		private String Predict(ArrayList<String> arr) {
			// TODO Auto-generated method stub
			HashMap<String, Double> tmp = new HashMap<String, Double>();
			for (int i = 0; i < arr.size(); i++) {
				if (tmp.containsKey(arr.get(i))) {
					double frequence = tmp.get(arr.get(i)) + 1;
					tmp.remove(arr.get(i));
					tmp.put((String) arr.get(i), frequence);
				} else
					tmp.put((String) arr.get(i), new Double(1));
			}
			Set<String> s = tmp.keySet();

			Iterator it = s.iterator();
			double lablemax = Double.MIN_VALUE;
			String predictlable = null;
			while (it.hasNext()) {
				String key = (String) it.next();
				Double lablenum = tmp.get(key);
				if (lablenum > lablemax) {
					lablemax = lablenum;
					predictlable = key;
				}
			}
			return predictlable;
		}
	}

	public static void main(String[] args) throws IOException,
			ClassNotFoundException, InterruptedException {
		FileSystem fs = FileSystem.get(new Configuration());

		Job job = new Job(new Configuration());
		job.setJarByClass(Knn.class);

		FileInputFormat.setInputPaths(job, new Path(args[0]));
		job.setMapperClass(KnnMap.class);
		job.setMapOutputKeyClass(Text.class);
		job.setMapOutputValueClass(Text.class);

		FileOutputFormat.setOutputPath(job, new Path(args[1]));
		job.setReducerClass(KnnReducer.class);
		job.setOutputKeyClass(Text.class);
		job.setOutputValueClass(NullWritable.class);

		job.waitForCompletion(true);

	}

}



  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值