MapReduce之K-均值聚类(完)
在上一篇博客MapReduce之K-均值聚类(一)中,介绍了K-均值聚类算法的基本原理,接下来讲述如何利用MapReduce实现K-均值聚类算法
MapReduce解决方案
K-均值聚类的MapReduce解决方案是一个迭代方案,其中每一次迭代对应一个MapReduce作业,因此需要创建一个迭代版本的MapReduce作业:
- 使用控制程序来初始化K个质心的位置,迭代调用MapReduce作业,并确定应当继续迭代还是应当停止
- 映射器需要获取数据点和所有簇质心,其中簇中心必须由所有映射器共享
- 当质心不再发生变化或者变化小于阈值时停止
输入数据如下
利用二维坐标来模拟算法输入数据
1.0 2.0
1.0 3.0
1.0 4.0
2.0 5.0
2.0 6.0
2.0 7.0
2.0 8.0
3.0 100.0
3.0 101.0
3.0 102.0
3.0 103.0
3.0 104.0
Mapper阶段任务
这个阶段有两个任务
- 1、使用setup()方法从SequenceFile中将簇质心读入内存
- 2、迭代处理对应各个输入键-值对的各个簇质心
- 计算欧氏距离并保存与输入点最近的质心
- 写出将由归约器处理的键值对,其中键为离输入点最近的簇质心
Mapper阶段编码
package com.deng.Kmeans;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import java.io.IOException;
import java.util.List;
public class KmeansMapper extends Mapper<LongWritable, Text,PointWritable,PointWritable> {
private List<PointWritable> centers=null;
private static List<PointWritable> readCentersFromSequenceFile() throws IOException {
List<PointWritable> points=KmeansUtil.readFromHDFS("Sequence/read");
// System.out.println("points is "+points.toString());
return points;
}
//从文件系统中读取簇质心
public void setup(Context context) throws IOException {
this.centers=readCentersFromSequenceFile();
}
public void map(LongWritable key, Text value, Context context){
PointWritable v=new PointWritable(value.toString());
PointWritable nearest=null;
double nearestDistance=Double.MAX_VALUE;
for(PointWritable center:centers){
double distance=KmeansUtil.calculate(center,v);
if(nearest==null){
nearest=center;
nearestDistance=distance;
}else{
if(nearestDistance>distance){
nearest=center;
nearestDistance=distance;
}
}
}
try {
context.write(nearest,v);
System.out.println("Mapper key is "+nearest);
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
自定义类PointWritable
自定义PointWritable如下
package com.deng.Kmeans;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
/**
 * A 2-D point with an integer cluster id, usable both as a plain data point
 * (id == 0) and as a cluster centroid (id > 0).
 *
 * Used as a map-output key, so equals()/hashCode() are overridden consistently
 * with compareTo(): Hadoop's default HashPartitioner routes keys by hashCode(),
 * and the inherited identity hash would scatter equal centroids across reducers.
 * (WritableComparable already extends Writable, so listing both is unnecessary.)
 */
public class PointWritable implements WritableComparable<PointWritable> {
    private double x;
    private double y;
    // Cluster id; 0 marks an ordinary data point.
    private Integer id;

    public PointWritable() {
        this.id = 0;
    }

    public PointWritable(double x, double y) {
        this.x = x;
        this.y = y;
        this.id = 0;
    }

    // Parses either "x y" (data point) or "id x y" (centroid, as produced by toString()).
    public PointWritable(String s) {
        set(s);
    }

    /**
     * Parses the point from text, stripping a single stray non-digit character
     * at the start/end of a token (e.g. brackets left by List.toString()).
     * NOTE(review): inputs with neither 2 nor 3 tokens leave the fields unset —
     * confirm upstream always produces one of those two shapes.
     */
    public void set(String s) {
        String[] num = s.split(" ");
        if (num.length == 3) {
            if (!Character.isDigit(num[1].charAt(0))) {
                num[1] = num[1].substring(1);
            }
            if (!Character.isDigit(num[2].charAt(num[2].length() - 1))) {
                num[2] = num[2].substring(0, num[2].length() - 1);
            }
            this.x = Double.parseDouble(num[1]);
            this.y = Double.parseDouble(num[2]);
            this.id = Integer.parseInt(num[0]);
        } else if (num.length == 2) {
            if (!Character.isDigit(num[0].charAt(0))) {
                num[0] = num[0].substring(1);
            }
            if (!Character.isDigit(num[1].charAt(num[1].length() - 1))) {
                num[1] = num[1].substring(0, num[1].length() - 1);
            }
            this.x = Double.parseDouble(num[0]);
            this.y = Double.parseDouble(num[1]);
            this.id = 0;
        }
        System.out.println(toString());
    }

    public double getX() {
        return x;
    }

    public double getY() {
        return y;
    }

    public Integer getId() {
        return id;
    }

    public void setId(int id) {
        this.id = id;
    }

    // Accumulates another point's coordinates (used by the reducer to sum a cluster).
    public void add(PointWritable o) {
        System.out.println(" when add operator , o is " + o);
        this.x += o.getX();
        this.y += o.getY();
        System.out.println("the result is " + toString());
    }

    // Divides the accumulated sums by the cluster size, yielding the mean.
    public void divide(int count) {
        this.x /= count;
        this.y /= count;
    }

    // Order (and key grouping) is by cluster id only.
    @Override
    public int compareTo(PointWritable o) {
        return this.getId().compareTo(o.getId());
    }

    // Equality must match compareTo() so equal keys hash to the same partition.
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof PointWritable)) {
            return false;
        }
        return this.id.equals(((PointWritable) o).id);
    }

    @Override
    public int hashCode() {
        return id.hashCode();
    }

    @Override
    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(this.id);
        dataOutput.writeDouble(this.x);
        dataOutput.writeDouble(this.y);
    }

    @Override
    public void readFields(DataInput dataInput) throws IOException {
        this.id = dataInput.readInt();
        this.x = dataInput.readDouble();
        this.y = dataInput.readDouble();
    }

    // "id x y" — the exact format set(String) can parse back.
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(getId()).append(" ").append(getX()).
                append(" ").append(getY());
        return sb.toString();
    }
}
工具类KmeansUtil
KmeansUtil工具类如下:
package com.deng.Kmeans;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.util.ReflectionUtils;
import org.junit.Test;
import java.io.*;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
public class KmeansUtil {
    /**
     * Initializes the cluster centers by taking the first k lines of the input
     * file as centroids.
     *
     * Fixes: the reader is now closed via try-with-resources (it previously
     * leaked), and the loop checks t < k before reading so no extra line is
     * consumed.
     */
    public static List<PointWritable> pick(int k, String p) throws IOException {
        System.out.println("init k");
        List<PointWritable> point = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(p))) {
            String str;
            int t = 0;
            System.out.println("k is " + k);
            while (t < k && (str = br.readLine()) != null) {
                System.out.println(str);
                point.add(new PointWritable(str));
                t++;
                System.out.println(point.get(t - 1));
            }
        }
        System.out.println(point.size());
        return point;
    }

    /**
     * Writes the cluster centers to a SequenceFile at path p, reassigning
     * sequential 1-based ids as it goes. Keys are NullWritable; values are
     * PointWritable centroids.
     */
    public static void writeToHDFS(List<PointWritable> point, String p) throws IOException {
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(URI.create(p), conf);
        Path path = new Path(p);
        int t = 0;
        NullWritable key = NullWritable.get();
        PointWritable value = new PointWritable();
        SequenceFile.Writer writer = null;
        try {
            writer = SequenceFile.createWriter(fs, conf, path, key.getClass(), value.getClass());
            System.out.println(point.size());
            for (int i = 0; i < point.size(); i++) {
                t++;
                point.get(i).setId(t);
                // Round-trip through the text form so the stored copy is independent.
                value = new PointWritable(point.get(i).toString());
                writer.append(key, value);
            }
        } catch (NullPointerException e) {
            // Best-effort: log and still close the stream in finally.
            e.printStackTrace();
        } finally {
            IOUtils.closeStream(writer);
        }
    }

    /**
     * Reads the cluster centers back from the SequenceFile at path p.
     * Values are re-parsed from their text form into fresh PointWritable objects.
     */
    public static List<PointWritable> readFromHDFS(String p) throws IOException {
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(URI.create(p), conf);
        Path path = new Path(p);
        List<PointWritable> points = new ArrayList<>();
        SequenceFile.Reader reader = null;
        try {
            reader = new SequenceFile.Reader(fs, path, conf);
            Writable key = (Writable) ReflectionUtils.newInstance(reader.getKeyClass(), conf);
            Writable value = (Writable) ReflectionUtils.newInstance(reader.getValueClass(), conf);
            while (reader.next(key, value)) {
                PointWritable test = new PointWritable(value.toString());
                System.out.println("test is " + test);
                points.add(test);
            }
        } finally {
            IOUtils.closeStream(reader);
        }
        System.out.println("reader from HDFS accomplishment");
        return points;
    }

    // Euclidean distance between two points.
    public static double calculate(PointWritable a, PointWritable b) {
        return Math.sqrt((a.getX() - b.getX()) * (a.getX() - b.getX()) +
                (a.getY() - b.getY()) * (a.getY() - b.getY()));
    }

    // Sum of pairwise distances among the centers (each unordered pair counted twice);
    // used only as a scalar fingerprint of a center set for convergence checking.
    public static double sumOfDistance(List<PointWritable> point) {
        double sum = 0;
        for (int i = 0; i < point.size(); i++) {
            for (int j = 0; j < point.size(); j++) {
                if (i == j) continue;
                sum += calculate(point.get(i), point.get(j));
            }
        }
        return sum;
    }

    // How much the center set moved between two iterations (difference of fingerprints).
    public static double change(List<PointWritable> a, List<PointWritable> b) {
        return Math.abs(sumOfDistance(a) - sumOfDistance(b));
    }

    // Copies the new centers produced by a MapReduce pass back to the location
    // the next iteration's mappers read from.
    public static void reWriteHDFS(String from, String to) throws IOException {
        System.out.println("now ,this operation is rewrite to HDFS, the File directory is " + from);
        List<PointWritable> points = readFromHDFS(from);
        System.out.println("when rewrite from HDFS ,the centers is " + points.toString());
        writeToHDFS(points, to);
    }
}
Reducer阶段任务
该阶段任务为重新计算所有簇的平均值,进而重新创建所有簇的质心,计算完成后,键为新的簇质心并写入到SequenceFile中
Reducer阶段编码
package com.deng.Kmeans;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Reducer;
import java.io.IOException;
public class KmeansReducer extends Reducer<PointWritable,PointWritable, NullWritable,PointWritable> {
public Integer id=0;
public void reduce(PointWritable key,Iterable<PointWritable> values,Context context){
PointWritable newCenter=new PointWritable();
int count=0;
for(PointWritable p:values){
count++;
newCenter.add(p);
}
System.out.println("before divide operator ,the newCenter is "+newCenter+" the count is "+count);
newCenter.divide(count);
newCenter.setId(++id);
System.out.println("after divide operator , the newCenter is "+newCenter);
try {
context.write(NullWritable.get(),newCenter);
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
驱动程序
K-均值聚类的驱动程序不仅仅简单包含MapReduce作业,还包括以下部分
- 初始化簇质心
- 保存当前簇质心和MapReduce计算后新的簇质心
- 如果当前簇质心和新的簇质心之差是在阈值范围内或超过了迭代次数,结束运行,打印当前簇质心,否则将新的簇质心赋值给当前簇质心并写入文件系统进行下一次迭代
package com.deng.Kmeans;
import com.deng.util.FileUtil;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
 * Iterative driver for K-means: initializes centroids, runs one MapReduce job
 * per iteration, and stops when the centroids move less than a threshold or
 * the iteration cap is exceeded.
 */
public class KmeansDriver {
    public static List<PointWritable> initial_centroids, new_centroids, current_centroids;

    public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException {
        System.out.println("***************************************");
        // Initialize centroids from the first 3 input points and share them
        // with the mappers via a SequenceFile.
        initial_centroids = KmeansUtil.pick(3, "input/kmeans.txt");
        System.out.println("init accomplishment");
        System.out.println(initial_centroids.toString());
        KmeansUtil.writeToHDFS(initial_centroids, "Sequence/read");
        System.out.println("write to hdfs accomplishment");
        current_centroids = initial_centroids;
        int iterators = 0; // iteration counter
        while (true) {
            run();
            // New centroids produced by this pass; publish them for the next pass.
            // (This single write replaces the previous writeToHDFS + redundant
            // reWriteHDFS of the same data.)
            List<PointWritable> p = KmeansUtil.readFromHDFS("Sequence/Output/part-r-00000");
            System.out.println(p.toString());
            System.out.println("now the MapReduce calculate accomplishment");
            KmeansUtil.writeToHDFS(p, "Sequence/read");
            System.out.println("now iterators is " + iterators);
            if (iterators > 10) break; // stop after the iteration cap
            iterators++;
            new_centroids = KmeansUtil.readFromHDFS("Sequence/read");
            System.out.println("newCentorids is " + new_centroids.toString());
            // Converged when the centroid set barely moved.
            if (KmeansUtil.change(new_centroids, current_centroids) < 0.0005) {
                System.out.println(KmeansUtil.change(new_centroids, current_centroids));
                break;
            } else {
                current_centroids = new_centroids;
            }
        }
        // Centroids are stable (or cap reached): print the final result.
        List<PointWritable> result = KmeansUtil.readFromHDFS("Sequence/read");
        System.out.println("the result is :");
        for (int i = 0; i < result.size(); i++) {
            System.out.println(result.get(i).getX() + " " + result.get(i).getY());
        }
    }

    // Configures and runs one K-means MapReduce pass; throws if the job fails
    // (the previous version silently discarded the exit code).
    public static void run() throws IOException, ClassNotFoundException, InterruptedException {
        FileUtil.deleteDirs("Sequence/Output");
        Configuration conf = new Configuration();
        String[] otherArgs = new String[]{"input/kmeans.txt", "Sequence/Output"};
        Job job = Job.getInstance(conf, "Kmeans"); // Job.getInstance over the deprecated constructor
        job.setJarByClass(KmeansDriver.class);
        job.setMapperClass(KmeansMapper.class);
        job.setReducerClass(KmeansReducer.class);
        job.setMapOutputKeyClass(PointWritable.class);
        job.setMapOutputValueClass(PointWritable.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(PointWritable.class);
        FileInputFormat.addInputPath(job, new Path(otherArgs[0]));
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        SequenceFileOutputFormat.setOutputPath(job, new Path(otherArgs[1]));
        if (!job.waitForCompletion(true)) {
            throw new IllegalStateException("Kmeans iteration job failed");
        }
    }
}
至于为什么会有如此多的System.out.println() ,因为我希望程序可以和大型软件一样实时的展现计算过程而已
遇到的问题
算法思想上没问题,就是在判断迭代次数的时候把条件写反了,检查了两周才找到错误。。。