用Hadoop实现KMeans算法

在我们阅读的时候,我们首先知道什么是KMeans:
K-means算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一。K-means算法的基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。



虽然已经发展到了hadoop2.4,但是对于一些算法只要明白其中的含义,是和语言无关的,无论是使用Java、C++、python等,
本文以Hadoop1.0.3为例。

从理论上来讲用MapReduce技术实现KMeans算法是很Natural的想法:在Mapper中逐个计算样本点离哪个中心最近,然后Emit(样本点所属的簇编号,样本点);在Reducer中属于同一个质心的样本点在一个链表中,方便我们计算新的中心,然后Emit(质心编号,质心)。但是技术上的事并没有理论层面那么简单。

Mapper和Reducer都要用到K个中心(我习惯称之为质心),Mapper要读这些质心,Reducer要写这些质心。另外Mapper还要读存储样本点的数据文件。我先后尝试以下3种方法,只有第3种是可行的,如果你不想被我误导,请直接跳过前两种。

一、用一个共享变量在存储K个质心

由于K很小,所以我们认为用一个Vector<Sample>来存储K个质心是没有问题的。以下代码是错误的:

  1. class MyJob extends Tool{
  2.   static Vector<Sample> centers=new Vector<Sample>(K);
  3.   static class MyMapper extends Mapper{
  4.     //read centers
  5.   } 
  6.   static class MyMapper extends Reducer{
  7.     //update centers
  8.   }
  9.   void run(){
  10.     until ( convergence ){
  11.       map();
  12.       reduce();
  13.     }
  14. }
复制代码

发生这种错误是因为对hadoop执行流程不清楚,对数据流不清楚。简单地说Mapper和Reducer作为MyJob的内部静态类,它们应该是独立的--它们不应该与MyJob有任何交互,因为Mapper和Reducer分别在Task Tracker的不同JVM中运行,而MyJob以及MyJob的内部其他类都在客户端上运行,自然不能在不同的JVM中共享一个变量。

详细的流程是这样的:
首先在客户端上,JVM加载MyJob时先初始化静态变量,执行static块。然后提交作业到Job Tracker。
在Job Tracker上,分配Mapper和Reducer到不同的Task Tracker上。Mapper和Reducer线程获得了MyJob类静态变量的初始拷贝(这份拷贝是指MyJob执行完静态块之后静态变量的模样)。
在Task Tracker上,Mapper和Reducer分别地读写MyJob的静态变量的本地拷贝,但是并不影响原始的MyJob中的静态变量的值。

二、用分布式缓存文件存储K个质心
既然不能通过共享外部类变量的方式,那我们通过文件在map和reduce之间传递数据总可以吧,Mapper从文件中读取质心,Reducer把更新后的质心再写入这个文件。这里的问题是:如果确定要把质心放在文件中,那Mapper就需要从2个文件中读取数据--质心文件和样本数据文件。虽然有MutipleInputs可以指定map()的输入文件有多个,并可以为每个输入文件分别指定解析方式,但是MutipleInputs不能保证每条记录从不同文件中传给map()的顺序。在我们的KMeans中,我们希望质心文件全部被读入后再逐条读入样本数据。

于是乎就想到了DistributedCache,它主要用于Mapper和Reducer之间共享数据。DistributedCacheFile是缓存在本地文件,在Mapper和Reducer中都可使用本地Java I/O的方式读取它。于是我又有了一个错误的思路:

  1. class MyMaper{
  2.     Vector<Sample> centers=new Vector<Sample>(K);
  3.     void setup(){
  4.         //读取cacheFile,给centers赋值
  5.     }
  6.     void map(){
  7.         //计算样本离哪个质心最近
  8.     }
  9. }
  10. class MyReducer{
  11.     Vector<Sample> centers=new Vector<Sample>(K);
  12.     void reduce(){
  13.         //更新centers
  14.     }
  15.     void cleanup(){
  16.         //把centers写回cacheFile
  17.     }
  18. }
复制代码



错因:DistributedCacheFile是只读的,在任务运行前,TaskTracker从JobTracker文件系统复制文件到本地磁盘作为缓存,这是单向的复制,是不能写回的。试想在分布式环境下,如果不同的mapper和reducer可以把缓存文件写回的话,那岂不又需要一套复杂的文件共享机制,严重地影响hadoop执行效率。

三、用分布式缓存文件存储样本数据
其实DistributedCache还有一个特点,它更适合于“大文件”(各节点内存容不下)缓存在本地。仅存储了K个质心的文件显然是小文件,与之相比样本数据文件才是大文件。

此时我们需要2个质心文件:一个存放上一次的质心prevCenterFile,一个存放reducer更新后的质心currCenterFile。Mapper从prevCenterFile中读取质心,Reducer把更新后有质心写入currCenterFile。在Driver中读入prevCenterFile和currCenterFile,比较前后两次的质心是否相同(或足够地接近),如果相同则停止迭代,否则就用currCenterFile覆盖prevCenterFile(使用fs.rename),进入下一次的迭代。

这时候Mapper就是这样的:
  1. class MyMaper{
  2.     Vector<Sample> centers=new Vector<Sample>(K);
  3.     void map(){
  4.         //逐条读取质心,给centers赋值
  5.     }
  6.     void cleanup(){
  7.         //逐行读取cacheFile,计算每个样本点离哪个质心最近
  8.         //然后Emit(样本点所属的簇编号,样本点)
  9.     }
  10. }
复制代码





源代码
试验数据是在Mahout项目中作为example提供的,600个样本点,每个样本是一个60维的浮点向量。  synthetic_control.data.zip (118.04 KB, 下载次数: 1) 

为样本数据建立一个类Sample.java。
  1. package kmeans;

  2. import java.io.DataInput;
  3. import java.io.DataOutput;
  4. import java.io.IOException;

  5. import org.apache.commons.logging.Log;
  6. import org.apache.commons.logging.LogFactory;
  7. import org.apache.hadoop.io.Writable;

  8. public class Sample implements Writable{
  9.     private static final Log log=LogFactory.getLog(Sample.class);
  10.     public static final int DIMENTION=60;
  11.     public double arr[];
  12.      
  13.     public Sample(){
  14.         arr=new double[DIMENTION];
  15.     }
  16.      
  17.     public static double getEulerDist(Sample vec1,Sample vec2){
  18.         if(!(vec1.arr.length==DIMENTION && vec2.arr.length==DIMENTION)){
  19.             log.error("vector's dimention is not "+DIMENTION);
  20.             System.exit(1);
  21.         }
  22.         double dist=0.0;
  23.         for(int i=0;i<DIMENTION;++i){
  24.             dist+=(vec1.arr[i]-vec2.arr[i])*(vec1.arr[i]-vec2.arr[i]);
  25.         }
  26.         return Math.sqrt(dist);
  27.     }
  28.      
  29.     public void clear(){
  30.         for(int i=0;i<arr.length;i++)
  31.             arr[i]=0.0;
  32.     }
  33.      
  34.     @Override
  35.     public String toString(){
  36.         String rect=String.valueOf(arr[0]);
  37.         for(int i=1;i<DIMENTION;i++)
  38.             rect+="\t"+String.valueOf(arr[i]);
  39.         return rect;
  40.     }

  41.     @Override
  42.     public void readFields(DataInput in) throws IOException {
  43.         String str[]=in.readUTF().split("\\s+");
  44.         for(int i=0;i<DIMENTION;++i)
  45.             arr[i]=Double.parseDouble(str[i]);
  46.     }

  47.     @Override
  48.     public void write(DataOutput out) throws IOException {
  49.         out.writeUTF(this.toString());
  50.     }
  51. }
复制代码


KMeans.java
  1. package kmeans;

  2. import java.io.BufferedReader;
  3. import java.io.FileReader;
  4. import java.io.IOException;
  5. import java.io.InputStreamReader;
  6. import java.util.Vector;

  7. import org.apache.commons.logging.Log;
  8. import org.apache.commons.logging.LogFactory;
  9. import org.apache.hadoop.conf.Configuration;
  10. import org.apache.hadoop.conf.Configured;
  11. import org.apache.hadoop.filecache.DistributedCache;
  12. import org.apache.hadoop.fs.FSDataInputStream;
  13. import org.apache.hadoop.fs.FileSystem;
  14. import org.apache.hadoop.fs.Path;
  15. import org.apache.hadoop.io.IntWritable;
  16. import org.apache.hadoop.io.LongWritable;
  17. import org.apache.hadoop.io.Text;
  18. import org.apache.hadoop.mapreduce.Job;
  19. import org.apache.hadoop.mapreduce.Mapper;
  20. import org.apache.hadoop.mapreduce.Reducer;
  21. import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
  22. import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
  23. import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
  24. import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
  25. import org.apache.hadoop.util.Tool;
  26. import org.apache.hadoop.util.ToolRunner;

  27. public class KMeans extends Configured implements Tool{
  28.     private static final Log log = LogFactory.getLog(KMeans2.class);

  29.     private static final int K = 10;
  30.     private static final int MAXITERATIONS = 300;
  31.     private static final double THRESHOLD = 0.01;
  32.      
  33.     public static boolean stopIteration(Configuration conf) throws IOException{
  34.         FileSystem fs=FileSystem.get(conf);
  35.         Path pervCenterFile=new Path("/user/orisun/input/centers");
  36.         Path currentCenterFile=new Path("/user/orisun/output/part-r-00000");
  37.         if(!(fs.exists(pervCenterFile) && fs.exists(currentCenterFile))){
  38.             log.info("两个质心文件需要同时存在");
  39.             System.exit(1);
  40.         }
  41.         //比较前后两次质心的变化是否小于阈值,决定迭代是否继续
  42.         boolean stop=true;
  43.         String line1,line2;
  44.         FSDataInputStream in1=fs.open(pervCenterFile);
  45.         FSDataInputStream in2=fs.open(currentCenterFile);
  46.         InputStreamReader isr1=new InputStreamReader(in1);
  47.         InputStreamReader isr2=new InputStreamReader(in2);
  48.         BufferedReader br1=new BufferedReader(isr1);
  49.         BufferedReader br2=new BufferedReader(isr2);
  50.         Sample prevCenter,currCenter;
  51.         while((line1=br1.readLine())!=null && (line2=br2.readLine())!=null){
  52.             prevCenter=new Sample();
  53.             currCenter=new Sample();
  54.             String []str1=line1.split("\\s+");
  55.             String []str2=line2.split("\\s+");
  56.             assert(str1[0].equals(str2[0]));
  57.             for(int i=1;i<=Sample.DIMENTION;i++){
  58.                 prevCenter.arr[i-1]=Double.parseDouble(str1[i]);
  59.                 currCenter.arr[i-1]=Double.parseDouble(str2[i]);
  60.             }
  61.             if(Sample.getEulerDist(prevCenter, currCenter)>THRESHOLD){
  62.                 stop=false;
  63.                 break;
  64.             }
  65.         }
  66.         //如果还要进行下一次迭代,就用当前质心替代上一次的质心
  67.         if(stop==false){
  68.             fs.delete(pervCenterFile,true);
  69.             if(fs.rename(currentCenterFile, pervCenterFile)==false){
  70.                 log.error("质心文件替换失败");
  71.                 System.exit(1);
  72.             }
  73.         }
  74.         return stop;
  75.     }
  76.      
  77.     public static class ClusterMapper extends Mapper<LongWritable, Text, IntWritable, Sample> {
  78.         Vector<Sample> centers = new Vector<Sample>();
  79.         @Override
  80.         //清空centers
  81.         public void setup(Context context){
  82.             for (int i = 0; i < K; i++) {
  83.                 centers.add(new Sample());
  84.             }
  85.         }
  86.         @Override
  87.         //从输入文件读入centers
  88.         public void map(LongWritable key, Text value, Context context)
  89.                 throws IOException, InterruptedException {
  90.             String []str=value.toString().split("\\s+");
  91.             if(str.length!=Sample.DIMENTION+1){
  92.                 log.error("读入centers时维度不对");
  93.                 System.exit(1);
  94.             }
  95.             int index=Integer.parseInt(str[0]);
  96.             for(int i=1;i<str.length;i++)
  97.                 centers.get(index).arr[i-1]=Double.parseDouble(str[i]);
  98.         }
  99.         @Override
  100.         //找到每个数据点离哪个质心最近
  101.         public void cleanup(Context context) throws IOException,InterruptedException {
  102.             Path []caches=DistributedCache.getLocalCacheFiles(context.getConfiguration());
  103.             if(caches==null || caches.length<=0){
  104.                 log.error("data文件不存在");
  105.                 System.exit(1);
  106.             }
  107.             BufferedReader br=new BufferedReader(new FileReader(caches[0].toString()));
  108.             Sample sample;
  109.             String line;
  110.             while((line=br.readLine())!=null){
  111.                 sample=new Sample();
  112.                 String []str=line.split("\\s+");
  113.                 for(int i=0;i<Sample.DIMENTION;i++)
  114.                     sample.arr[i]=Double.parseDouble(str[i]);
  115.                  
  116.                 int index=-1;
  117.                 double minDist=Double.MAX_VALUE;
  118.                 for(int i=0;i<K;i++){
  119.                     double dist=Sample.getEulerDist(sample, centers.get(i));
  120.                     if(dist<minDist){
  121.                         minDist=dist;
  122.                         index=i;
  123.                     }
  124.                 }
  125.                 context.write(new IntWritable(index), sample);
  126.             }
  127.         }
  128.     }
  129.      
  130.     public static class UpdateCenterReducer extends Reducer<IntWritable, Sample, IntWritable, Sample> {
  131.         int prev=-1;
  132.         Sample center=new Sample();;
  133.         int count=0;
  134.         @Override
  135.         //更新每个质心(除最后一个)
  136.         public void reduce(IntWritable key,Iterable<Sample> values,Context context) throws IOException,InterruptedException{
  137.             while(values.iterator().hasNext()){
  138.                 Sample value=values.iterator().next();
  139.                 if(key.get()!=prev){
  140.                     if(prev!=-1){
  141.                         for(int i=0;i<center.arr.length;i++)
  142.                             center.arr[i]/=count;       
  143.                         context.write(new IntWritable(prev), center);
  144.                     }
  145.                     center.clear();
  146.                     prev=key.get();
  147.                     count=0;
  148.                 }
  149.                 for(int i=0;i<Sample.DIMENTION;i++)
  150.                     center.arr[i]+=value.arr[i];
  151.                 count++;
  152.             }
  153.         }
  154.         @Override
  155.         //更新最后一个质心
  156.         public void cleanup(Context context) throws IOException,InterruptedException{
  157.             for(int i=0;i<center.arr.length;i++)
  158.                 center.arr[i]/=count;
  159.             context.write(new IntWritable(prev), center);
  160.         }
  161.     }

  162.     @Override
  163.     public int run(String[] args) throws Exception {
  164.         Configuration conf=getConf();
  165.         FileSystem fs=FileSystem.get(conf);
  166.         Job job=new Job(conf);
  167.         job.setJarByClass(KMeans.class);
  168.          
  169.         //质心文件每行的第一个数字是索引
  170.         FileInputFormat.setInputPaths(job, "/user/orisun/input/centers");
  171.         Path outDir=new Path("/user/orisun/output");
  172.         fs.delete(outDir,true);
  173.         FileOutputFormat.setOutputPath(job, outDir);
  174.          
  175.         job.setInputFormatClass(TextInputFormat.class);
  176.         job.setOutputFormatClass(TextOutputFormat.class);
  177.         job.setMapperClass(ClusterMapper.class);
  178.         job.setReducerClass(UpdateCenterReducer.class);
  179.         job.setOutputKeyClass(IntWritable.class);
  180.         job.setOutputValueClass(Sample.class);
  181.          
  182.         return job.waitForCompletion(true)?0:1;
  183.     }
  184.     public static void main(String[] args) throws Exception {
  185.         Configuration conf = new Configuration();
  186.         FileSystem fs=FileSystem.get(conf);
  187.          
  188.         //样本数据文件中每个样本不需要标记索引
  189.         Path dataFile=new Path("/user/orisun/input/data");
  190.         DistributedCache.addCacheFile(dataFile.toUri(), conf);

  191.         int iteration = 0;
  192.         int success = 1;
  193.         do {
  194.             success ^= ToolRunner.run(conf, new KMeans(), args);
  195.             log.info("iteration "+iteration+" end");
  196.         } while (success == 1 && iteration++ < MAXITERATIONS
  197.                 && (!stopIteration(conf)));
  198.         log.info("Success.Iteration=" + iteration);
  199.          
  200.         //迭代完成后再执行一次mapper,输出每个样本点所属的分类--在/user/orisun/output2/part-m-00000中
  201.         //质心文件保存在/user/orisun/input/centers中
  202.         Job job=new Job(conf);
  203.         job.setJarByClass(KMeans.class);
  204.          
  205.         FileInputFormat.setInputPaths(job, "/user/orisun/input/centers");
  206.         Path outDir=new Path("/user/orisun/output2");
  207.         fs.delete(outDir,true);
  208.         FileOutputFormat.setOutputPath(job, outDir);
  209.          
  210.         job.setInputFormatClass(TextInputFormat.class);
  211.         job.setOutputFormatClass(TextOutputFormat.class);
  212.         job.setMapperClass(ClusterMapper.class);
  213.         job.setNumReduceTasks(0);
  214.         job.setOutputKeyClass(IntWritable.class);
  215.         job.setOutputValueClass(Sample.class);
  216.          
  217.         job.waitForCompletion(true);
  218.     }
  219. }
复制代码


注意在Driver中创建Job实例时一定要把Configuration类型的参数传递进去,否则在Mapper或Reducer中调用DistributedCache.getLocalCacheFiles(context.getConfiguration());返回值就为null。因为空构造函数的Job采用的Configuration是从hadoop的配置文件中读出来的(使用new Configuration()创建的Configuration就是从hadoop的配置文件中读出来的),请注意在main()函数中有一句:DistributedCache.addCacheFile(dataFile.toUri(), conf);即此时的Configuration中多了一个DistributedCacheFile,所以你需要把这个Configuration传递给Job构造函数,如果传递默认的Configuration,那在Job中当然不知道DistributedCacheFile的存在了。

Further
方案三还是不如人意,质心文件是很小的(因为质心总共就没几个),用map()函数仅仅是来读一个质心文件根本就没有发挥并行的作用,而且在map()中也没有调用context.write(),所以Mapper中做的事情可以放在Reducer的setup()中来完成,这样就不需要Mapper了,或者说上面设计的就不是MapReduce程序,跟平常的单线程串行程序是一样的。
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值