1. Data description
This experiment uses the iris dataset. Training data: three classes of iris, 40 records per class, 120 records in total. Test data: 30 records.
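Both files are plain CSV. The mapper expects each training row to hold four feature values followed by the class label, and each test row to hold exactly the four feature values with no label (it fills a fixed-size array of four). Hypothetical rows (the exact values and label strings depend on your copy of the dataset):

5.1,3.5,1.4,0.2,Iris-setosa    (training row)
5.9,3.0,4.2,1.5                (test row)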
2. Environment: IntelliJ IDEA 2019.3.1 x64; Hadoop 2.8.5
3. IDEA code:
package com.knn;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.HashMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
public class Knn {

    public static class KnnMapper
            extends Mapper<Object, Text, IntWritable, Text> {

        // Test set, loaded once per map task from the distributed cache
        private final List<String> test = new ArrayList<String>();

        @Override
        protected void setup(Context context) throws IOException {
            // Get the localized paths of the cache files
            // (DistributedCache is deprecated but still works on Hadoop 2.x)
            Path[] paths = DistributedCache.getLocalCacheFiles(context.getConfiguration());
            BufferedReader reader = new BufferedReader(new FileReader(paths[0].toUri().getPath()));
            String line;
            while ((line = reader.readLine()) != null) {
                test.add(line);
            }
            reader.close();
        }
        /**
         * Euclidean distance between two feature vectors.
         */
        private double distance(Double[] a, Double[] b) {
            double sum = 0.0;
            for (int i = 0; i < a.length; i++) {
                sum += Math.pow(a[i] - b[i], 2);
            }
            return Math.sqrt(sum);
        }
        @Override
        public void map(Object key, Text value, Context context)
                throws IOException, InterruptedException {
            // One training record: four feature values followed by the class label
            String[] train = value.toString().split(",");
            String label = train[train.length - 1];
            Double[] trainPoint = new Double[4];
            for (int i = 0; i < train.length - 1; i++) {
                trainPoint[i] = Double.valueOf(train[i]);
            }
            // Compare this training point against every test point
            for (int i = 0; i < test.size(); i++) {
                // Test records must contain exactly four feature values (no label)
                String[] fields = test.get(i).split(",");
                Double[] testPoint = new Double[4];
                for (int j = 0; j < fields.length; j++) {
                    testPoint[j] = Double.valueOf(fields[j]);
                }
                // Key: test point ID; value: distance to this training point + "@" + its label
                context.write(new IntWritable(i),
                        new Text(distance(testPoint, trainPoint) + "@" + label));
            }
        }
    }
    public static class KnnCombiner
            extends Reducer<IntWritable, Text, IntWritable, Text> {

        private int k;

        @Override
        protected void setup(Context context) {
            k = context.getConfiguration().getInt("K", 1);
        }

        @Override
        public void reduce(IntWritable key, Iterable<Text> values, Context context)
                throws IOException, InterruptedException {
            // TreeMap sorts by key in ascending order, so the smallest distances
            // come first. Note that two training points at exactly the same
            // distance overwrite each other here.
            TreeMap<Double, String> treemap = new TreeMap<Double, String>();
            for (Text val : values) {
                // A value is one or more "distance@label" pairs joined by "@"
                String[] parts = val.toString().split("@");
                for (int i = 0; i < parts.length - 1; i += 2) {
                    treemap.put(Double.valueOf(parts[i]), parts[i + 1]);
                }
            }
            // Keep only the k nearest neighbors seen by this map task,
            // re-encoded as a single "distance@label@distance@label..." string
            StringBuilder valueInfo = new StringBuilder();
            int num = 0;
            for (Map.Entry<Double, String> entry : treemap.entrySet()) {
                valueInfo.append(entry.getKey()).append("@").append(entry.getValue()).append("@");
                if (++num >= k) {
                    break;
                }
            }
            context.write(key, new Text(valueInfo.toString()));
        }
    }
    public static class KnnReducer
            extends Reducer<IntWritable, Text, IntWritable, Text> {

        private final Text result = new Text();
        private int k;

        @Override
        protected void setup(Context context) {
            k = context.getConfiguration().getInt("K", 1);
        }

        @Override
        public void reduce(IntWritable key, Iterable<Text> values, Context context)
                throws IOException, InterruptedException {
            // Merge the candidates from all map tasks, sorted by distance ascending
            TreeMap<Double, String> treemap = new TreeMap<Double, String>();
            for (Text val : values) {
                String[] parts = val.toString().split("@");
                for (int i = 0; i < parts.length - 1; i += 2) {
                    treemap.put(Double.valueOf(parts[i]), parts[i + 1]);
                }
            }
            // Count the class labels among the k globally nearest neighbors
            Map<String, Integer> votes = new HashMap<String, Integer>();
            int num = 0;
            for (String label : treemap.values()) {
                Integer count = votes.get(label);
                votes.put(label, count == null ? 1 : count + 1);
                if (++num >= k) {
                    break;
                }
            }
            // The label with the most votes is the predicted class of this test point
            String bestLabel = null;
            int bestCount = -1;
            for (Map.Entry<String, Integer> entry : votes.entrySet()) {
                if (entry.getValue() > bestCount) {
                    bestLabel = entry.getKey();
                    bestCount = entry.getValue();
                }
            }
            result.set(bestLabel);
            context.write(key, result);
        }
    }
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        // Input (training data) and output paths; the output path
        // must not exist before the job runs
        String[] otherArgs = new String[] {
            "hdfs://hadoop:9000/KNN/train",
            "hdfs://hadoop:9000/KNN/label"
        };
        if (otherArgs.length < 2) {
            System.err.println("Usage: Knn <in> [<in>...] <out>");
            System.exit(2);
        }
        // Number of nearest neighbors
        conf.setInt("K", 10);
        Job job = Job.getInstance(conf, "Knn");
        job.setJarByClass(Knn.class);
        // Ship the test set to every map task via the distributed cache
        job.addCacheFile(new URI("hdfs://hadoop:9000/KNN/iris/iris_test_data.csv"));
        job.setMapperClass(KnnMapper.class);
        job.setCombinerClass(KnnCombiner.class);
        job.setReducerClass(KnnReducer.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Text.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Text.class);
        for (int i = 0; i < otherArgs.length - 1; ++i) {
            FileInputFormat.addInputPath(job, new Path(otherArgs[i]));
        }
        FileOutputFormat.setOutputPath(job, new Path(otherArgs[otherArgs.length - 1]));
        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }
}
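Before packaging, it helps to trace one record through the job; the numbers below are invented for illustration. For test point 0, each map call emits one (index, "distance@label") pair per training record; the combiner keeps the K nearest seen by its map task and packs them into one string; the reducer merges all candidates, takes the K globally nearest, and writes the majority label:

map output:      0    3.1623@Iris-virginica
combine output:  0    0.5385@Iris-setosa@0.6083@Iris-setosa@...
reduce output:   0    Iris-setosa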
4. Package the code into a jar.
5. Start the Hadoop cluster: start-dfs.sh, then start-yarn.sh.
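You can confirm the daemons are up with jps; on a single-node setup it should list NameNode, DataNode, SecondaryNameNode, ResourceManager, and NodeManager (process names vary with cluster layout):

jps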
6. Transfer the data files to the virtual machine using WinSCP.
7. Create a directory named KNN on HDFS (the source code hard-codes this path, so do not change it).
8. Under the KNN directory, create the iris and train subdirectories. (Note: do not create the label directory; the program creates it automatically!)
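These directories can be created from the shell (assuming the HDFS paths hard-coded in the code):

hdfs dfs -mkdir /KNN
hdfs dfs -mkdir /KNN/iris /KNN/train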
9. Change to the directory containing the data files.
10. Upload the train data to the /KNN/train/ directory.
11. Upload the test data to the /KNN/iris/ directory.
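The uploads can be done with hdfs dfs -put. The test file name must match the one hard-coded in main(); the train file name here is only an assumed example:

hdfs dfs -put iris_train_data.csv /KNN/train/
hdfs dfs -put iris_test_data.csv /KNN/iris/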
12. Run the jar, for example with the command below.
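A typical invocation (knn.jar is an assumed jar name; com.knn.Knn is the main class from the code above):

hadoop jar knn.jar com.knn.Knn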
13. After the job finishes, a new label folder appears.
14. The part-r-00000 file inside it holds the predictions.
15. You can download it and open it with Notepad++.
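It can also be inspected straight from the shell. Each output line is a tab-separated pair of test point index and predicted label (the rows below are illustrative, not actual results):

hdfs dfs -cat /KNN/label/part-r-00000

0    Iris-setosa
1    Iris-versicolor
2    Iris-virginica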
In the final result, test point 27 is predicted incorrectly.
I'm a beginner here; feedback and corrections are welcome.