MapReduce实现KNN算法
一、实验任务要求
本实验是为了通过实现K近邻算法(KNN)将所给鸢尾花数据集进行分类。要求根据所给数据,在“训练集大,测试集小”的情况下实现KNN算法,并通过使用Combiner提高算法并行化工作效率,减轻Reducer的负担。
二、实验工具和环境配置说明
电脑安装了Vmware软件,搭建Centos7系统环境,配置了Hadoop2.7.4单机伪分布环境并安装eclipse编程软件,使用Java语言完成实验任务。
数据采用的鸢尾花的demo数据添加链接描述提取码:ada4
三、步骤
1.对问题的分析
KNN算法的基本思想是,将每个无标签的测试数据i与有标签的训练数据j,逐个计算得到距离。将距离排序得到前K项与测试样本i距离最近的训练样本的标签,占比最多的标签类别即为测试样本i的类别。
而现如今,我们要使用MapReduce实现KNN算法,所以就有两种可能性的假设。一是“训练集大,测试集小”,二是“测试集大,训练集小”,课本上实现的是后者,方法比较简单,而我们这里要实现的前者。所以,测试集应该作为缓存文件在每个节点上共享,训练集作为输入文件。首先,在执行map任务之前,调用setup方法将测试集读入,存储为全局变量保存在本地。
在map中读入训练样本,计算与每个测试样本之间的距离distance,map输出的样式为<测试样本的ID,距离distance+训练样本的标签>,注意key是测试样本的ID,因为我们最终在Reduce阶段要确定的是每个测试样本的标签,所以要将测试样本聚合。
Reduce阶段,我们要将同一测试样本与所有训练样本的距离进行排序,选出topK离测试样本最近的训练样本,将这K个训练样本的标签进行计数,计数最大的标签即是测试样本的标签。
优化方法,在处理大量数据时,Map的输出数量会是NM(N为测试样本数量,M为训练样本数量),分布式的情况下会导致通信负担加重。这时,我们可以使用Combiner,减小通信的负担。这里我使用Combiner在节点内部先筛选出topK再发给Reducer。
这里对Combiner中筛选topK会不会引发问题,举例解释一下。比如:本来的分类结果中某节点应该被划分为A类,决定该节点划分情况的样本是所有训练样本中距离该测试样本距离最近的前k个训练样本,那么这样的话,这k个训练样本在分布式的每个节点上也是排名前k个的。总的来说,就是全局排名topK,必然会保证局部排名topK以内。所以我们在Combiner中对每个测试样本,只保留离他最近的topK不会影响最后结果,还能删去许多不重要的工作量与通信量。
2.代码实现
1.设置缓存文件并用setup形成全局变量
设置测试集为缓存文件。
将测试集转化为一个数组,其中每个元素为一个字符串类型的测试样本。
2.Map阶段
首先定义计算距离的方法,这里使用欧式距离。
接下来map阶段,读入一个训练样本,将标签(label)提取出来,将属性变成一个Double类型的数组(train_point);将测试样本逐个转变成Double类型的数组(test_point),使用定义好的计算距离方法计算一下距离distance。将测试样本在测试集中的行偏移量作为键key(此处就是数组中的索引),距离加分隔符再加上标签作为value。也就是<测试ID(key),distance@label(value)>。
3.Combiner阶段
这里采用treemap形式存储同一测试样本下,与训练样本的距离和标签。其中距离作为map的键,标签作为值。Treemap会在形成的过程中自动对键key排序,默认是升序,正好符合我们对距离小的排在前面的要求。将前K项按照刚才的格式输出。(K在整个Iris类中定义)
4.Reduce阶段
首先和Combiner的阶段一样,先排序,再找到前K项最近的,但是这里多了一步对前K项标签进行计数统计,找出数量最多的标签为测试样本的标签。
5.结果展示
代码如下:
import java.io.BufferedReader;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.file.FileStore;
import java.nio.file.PathMatcher;
import java.nio.file.WatchService;
import java.nio.file.attribute.UserPrincipalLookupService;
import java.nio.file.spi.FileSystemProvider;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.HashMap;
import java.util.Set;
import java.util.StringTokenizer;
import java.util.Iterator;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
public class Iris {
public static class TokenizerMapper
extends Mapper < Object, Text, IntWritable, Text > {
private final static IntWritable one = new IntWritable(1);
private Text word = new Text();
static List < String > test = new ArrayList < String > (); //存储test测试集
@Override
protected void setup(Context context) throws IOException,
InterruptedException { //获取缓存文件路径的数组
Path[] paths = DistributedCache.getLocalCacheFiles(context.getConfiguration());
System.out.println(paths);
BufferedReader sb = new BufferedReader(new FileReader(paths[0].toUri().getPath()));
//读取BufferedReader里面的数据
String tmp = null;
while ((tmp = sb.readLine()) != null) {
test.add(tmp);
}
//关闭sb对象
sb.close();
System.out.println("+++++++" + test);
}
//计算欧式距离
private double Distance(Double[] a, Double[] b) {
// TODO Auto-generated method stub
double sum = 0.0;
for (int i = 0; i < a.length; i++) {
sum += Math.pow(a[i] - b[i], 2);
}
return Math.sqrt(sum);
}
public void map(Object key, Text value, Context context) throws IOException,
InterruptedException {
String train[] = value.toString().split(",");
String lable = train[train.length - 1];
//训练集由字符格式转化为Double数组
Double[] train_point = new Double[4];
for (int i = 0; i < train.length - 1; i++) {
train_point[i] = Double.valueOf(train[i]);
}
//测试集由字符格式转化为Double数组
for (int i = 0; i < test.size(); i++) {
String test_poit1[] = test.get(i).toString().split(",");
Double[] test_poit = new Double[4];
for (int j = 0; j < test_poit1.length; j++) {
test_poit[j] = Double.valueOf(test_poit1[j]);
}
//每个测试点的ID作为键key,计算每个测试点与该训练点的距离+"@"+类标签 作为value
context.write(new IntWritable(i), new Text(String.valueOf(Distance(test_poit, train_point)) + "@" + lable));
}
}
}
public static class InvertedIndexCombiner extends Reducer < IntWritable, Text, IntWritable, Text > {
private Text info = new Text();
int k;
protected void setup(Context context) throws IOException,
InterruptedException {
Configuration conf=context.getConfiguration();
k=conf.getInt("K", 1);
}
public void reduce(IntWritable key, Iterable < Text > values,
Context context
) throws IOException,
InterruptedException {
//排序
TreeMap < Double, String > treemap = new TreeMap < Double, String > ();
int sum = 0;
for (Text val: values) {
String distance_lable[] = val.toString().split("@");
for (int i = 0; i < distance_lable.length - 1; i = i + 2) {
treemap.put(Double.valueOf(distance_lable[i]), distance_lable[i + 1]);
//treemap会自动按key升序排序,也就是距离小的排前面
}
}
//得到前k项距离最近
Iterator < Double > it = treemap.keySet().iterator();
Map < String, Integer > map = new HashMap < String, Integer > ();
int num = 0;
String valueinfo="";
while (it.hasNext()) {
Double key1 = it.next();
valueinfo+=String.valueOf(key1)+"@"+treemap.get(key1)+"@";
num++;
if (num >k)
break;
}
context.write(key,new Text(valueinfo));
}
}
public static class IntSumReducer
extends Reducer < IntWritable, Text, IntWritable, Text > {
private Text result = new Text();
int k;
protected void setup(Context context) throws IOException,
InterruptedException {
Configuration conf=context.getConfiguration();
k=conf.getInt("K", 1);
}
public void reduce(IntWritable key, Iterable < Text > values,
Context context
) throws IOException,
InterruptedException {
//排序
TreeMap < Double, String > treemap = new TreeMap < Double, String > ();
int sum = 0;
for (Text val: values) {
String distance_lable[] = val.toString().split("@");
for (int i = 0; i < distance_lable.length - 1; i = i + 2) {
treemap.put(Double.valueOf(distance_lable[i]), distance_lable[i + 1]);
//treemap会自动按key升序排序,也就是距离小的排前面
}
}
//得到前k项距离最近
Iterator < Double > it = treemap.keySet().iterator();
Map < String, Integer > map = new HashMap < String, Integer > ();
int num = 0;
while (it.hasNext()) {
Double key1 = it.next();
if (map.containsKey(treemap.get(key1))) {
int temp = map.get(treemap.get(key1));
map.put(treemap.get(key1), temp + 1);
} else {
map.put(treemap.get(key1), 1);
}
//System.out.println(key1+"="+treemap.get(key1));
num++;
if (num > k)
break;
}
//得到排名最靠前的标签为test的类别
Iterator < String > it1 = map.keySet().iterator();
String lable = it1.next();
int count = map.get(lable);
while (it1.hasNext()) {
String now = it1.next();
if (count < map.get(now)) {
lable = now;
count = map.get(lable);
}
}
result.set(lable);
context.write(key, result);
}
}
public static void main(String[] args) throws Exception {
//任务一
Configuration conf = new Configuration();
//FileSystem hdfs= FileSystem.get(conf);
//conf.set("stop", "hdfs://localhost:9000/input/stopwords.txt");
//DistributedCache.addCacheFile(new URI("hdfs://localhost:9000/discache"), conf);
//String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
String[] otherArgs = new String[] {
"hdfs://localhost:9000/input/iris/train",
"hdfs://localhost:9000/output/lable"
};
if (otherArgs.length < 2) {
System.err.println("Usage: wordcount <in> [<in>...] <out>");
System.exit(2);
}
conf.setInt("K",5);
Job job1 = Job.getInstance(conf, "word count");
job1.setJarByClass(WordCount.class);
//设置分布式缓存文件
job1.addCacheFile(new URI("hdfs://localhost:9000/input/iris/test/iris_test_data.csv"));
job1.setMapperClass(TokenizerMapper.class);
job1.setCombinerClass(InvertedIndexCombiner.class);
job1.setReducerClass(IntSumReducer.class);
job1.setMapOutputKeyClass(IntWritable.class);
job1.setMapOutputValueClass(Text.class);
job1.setOutputKeyClass(IntWritable.class);
job1.setOutputValueClass(Text.class);
for (int i = 0; i < otherArgs.length - 1; ++i) {
FileInputFormat.addInputPath(job1, new Path(otherArgs[i]));
}
FileOutputFormat.setOutputPath(job1,
new Path(otherArgs[otherArgs.length - 1]));
System.exit(job1.waitForCompletion(true) ? 0 : 1);
}
}