KNN Hadoop MapReduce

K Nearest Neighbor算法

K Nearest Neighbor算法又叫KNN算法,这个算法是机器学习里面一个比较经典的算法, 总体来说KNN算法是相对比较容易理解的算法。其中的K表示最接近自己的K个数据样本。KNN算法和K-Means算法不同的是,K-Means算法用来聚类,用来判断哪些东西是一个比较相近的类型,而KNN算法是用来做归类的,也就是说,有一个样本空间里的样本分成很几个类型,然后,给定一个待分类的数据,通过计算接近自己最近的K个样本来判断这个待分类数据属于哪个分类。你可以简单的理解为由那离自己最近的K个点来投票决定待分类数据归为哪一类。
根据距离函数计算待分类样本X和每个训练样本的距离(作为相似度),选择与待分类样本距离最小的K个样本作为X的K个最邻近,最后以X的K个最邻近中的大多数所属的类别作为X的类别。KNN可以说是一种最直接的用来分类未知数据的方法。
简单来说,KNN可以看成:有那么一堆你已经知道分类的数据,然后当一个新数据进入的时候,就开始跟训练数据里的每个点求距离,然后挑出离这个数据最近的K个点,看看这K个点属于什么类型,然后用少数服从多数的原则,给新数据归类。
这里写图片描述
这里写图片描述
从上图中我们可以看到,图中的有两个类型的样本数据,一类是蓝色的正方形,另一类是红色的三角形。而那个绿色的圆形是我们待分类的数据。
如果K=3,那么离绿色点最近的有2个红色三角形和1个蓝色的正方形,这3个点投票,于是绿色的这个待分类点属于红色的三角形。
如果K=5,那么离绿色点最近的有2个红色三角形和3个蓝色的正方形,这5个点投票,于是绿色的这个待分类点属于蓝色的正方形。

源码

package org.bigdata.mapreduce.knn;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.bigdata.util.DistanceUtil;
import org.bigdata.util.HadoopCfg;
import org.bigdata.util.HadoopUtil;

public class KNNMapReduce {

    public static final String POINTS = "knn-test.txt";
    public static final int K = 3;
    public static final int TYPES = 3;
    private static final String JOB_NAME = "knn";

    // train-points
    private static List<Point> trans_points = new ArrayList<>();

    public static void initPoints(String pathin, String filename)
            throws IOException {
        List<String> lines = HadoopUtil.lslFile(pathin, filename);
        for (String line : lines) {
            trans_points.add(new Point(line));
        }
    }

    public static class KNNMapper extends
            Mapper<LongWritable, Text, Text, Text> {

        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            FileSplit fileSplit = (FileSplit) context.getInputSplit();
            String fileName = fileSplit.getPath().getName();
            if (POINTS.equals(fileName)) {
                Point point1 = new Point(value.toString());
                try {
                    for (Point point2 : trans_points) {
                        double dis = DistanceUtil.getEuclideanDisc(
                                point1.getV(), point2.getV());
                        context.write(new Text(point1.toString()), new Text(
                                point2.getType() + ":" + dis));
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }

    }

    public static class KNNReducer extends
            Reducer<Text, Text, Text, IntWritable> {

        @Override
        protected void reduce(Text key, Iterable<Text> values, Context context)
                throws IOException, InterruptedException {
            List<DistanceType> list = new ArrayList<>();
            for (Text value : values) {
                list.add(new DistanceType(value.toString()));
            }
            Collections.sort(list);
            int cnt[] = new int[TYPES + 1];
            for (int i = 0, len = cnt.length; i < len; i++) {
                cnt[i] = 0;
            }
            for (int i = 0; i < K; i++) {
                cnt[list.get(i).getType()]++;
            }
            int type = 0;
            int maxx = Integer.MIN_VALUE;
            for (int i = 1; i <= TYPES; i++) {
                if (cnt[i] > maxx) {
                    maxx = cnt[i];
                    type = i;
                }
            }
            context.write(key, new IntWritable(type));
        }

    }

    public static void solve(String pointin, String pathout)
            throws ClassNotFoundException, InterruptedException {
        try {
            Configuration cfg = HadoopCfg.getConfiguration();
            Job job = Job.getInstance(cfg);
            job.setJobName(JOB_NAME);
            job.setJarByClass(KNNMapReduce.class);

            // mapper
            job.setMapperClass(KNNMapper.class);
            job.setMapOutputKeyClass(Text.class);
            job.setMapOutputValueClass(Text.class);

            // reducer
            job.setReducerClass(KNNReducer.class);
            job.setOutputKeyClass(Text.class);
            job.setOutputValueClass(IntWritable.class);

            FileInputFormat.addInputPath(job, new Path(pointin));
            FileOutputFormat.setOutputPath(job, new Path(pathout));

            job.waitForCompletion(true);

        } catch (IllegalStateException | IllegalArgumentException | IOException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) throws ClassNotFoundException,
            InterruptedException, IOException {
        initPoints("/knn", "knn-train.txt");
        solve("/knn", "/knn_out");
    }

}
package org.bigdata.mapreduce.knn;

public class DistanceType implements Comparable<DistanceType> {

    private double distance;
    private int type;

    public double getDistance() {
        return distance;
    }

    public void setDistance(double distance) {
        this.distance = distance;
    }

    public int getType() {
        return type;
    }

    public void setType(int type) {
        this.type = type;
    }

    public DistanceType(String s) {
        super();
        String terms[] = s.split(":");
        this.type = Integer.valueOf(terms[0]);
        this.distance = Double.valueOf(terms[1]);
    }

    @Override
    public int compareTo(DistanceType o) {
        return this.getDistance() > o.getDistance() ? 1 : -1;
    }

    @Override
    public String toString() {
        return "DistanceType [distance=" + distance + ", type=" + type + "]";
    }

}
package org.bigdata.mapreduce.knn;

import java.util.Vector;

public class Point {

    private int type;
    private String strpoint;
    private Vector<Double> v = new Vector<>();

    public int getType() {
        return type;
    }

    public void setType(int type) {
        this.type = type;
    }

    public Vector<Double> getV() {
        return v;
    }

    public void setV(Vector<Double> v) {
        this.v = v;
    }

    public Point(String s) {
        super();
        this.strpoint=s;
        String terms[]=s.split(" ");
        for(int i=0,len=terms.length;i<len-1;i++){
            this.v.add(Double.valueOf(terms[i]));
        }
        this.type=Integer.valueOf(terms[terms.length-1]);
    }

    public Point() {
        super();
    }

    @Override
    public String toString() {
        return this.strpoint;
    }

}

输入

1.0 2.0 3.0 1
1.0 2.1 3.1 1
0.9 2.2 2.9 1
3.4 6.7 8.9 2
3.0 7.0 8.7 2
3.3 6.9 8.8 2
2.5 3.3 10.0 3
2.4 2.9 8.0 3
2.1 5.5 7.2 0
1.1 2.5 4.2 0
4.1 3.5 9.2 0
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值