A note on kmeans up front: the algorithm has one hard requirement, namely that the number of clusters k must be fixed in advance. You might worry that this constraint hurts clustering quality, but the worry is largely unfounded. In the many years since the algorithm was introduced, it has proven broadly effective on real-world problems: even when k is suboptimal, the quality of the clustering does not suffer much.
A major real-world application of clustering is grouping news articles to obtain top-level categories such as politics, science, sports, and finance. For that we would pick a fairly small k, perhaps between 10 and 20. If fine-grained topics are needed, a larger k is required. To get good clustering quality, we first need a rough estimate of k. The crudest approach is to estimate it from the data volume and the desired cluster size: for example, with 1,000,000 news articles and a target of roughly 500 articles per cluster, we can simply estimate k = 1000000 / 500 = 2000.
One thing must be made clear: the decisive factor in kmeans clustering quality is the distance measure being used.
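To make that concrete, here is a small self-contained sketch (an illustration only, not part of the reference code below) of two common distance measures; switching from one to the other can change which clusters kmeans finds:

public class Distances {
    // Euclidean distance: the measure used by classic kmeans
    // (and by the reference code later in this post).
    public static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Cosine distance (1 - cosine similarity): often a better fit for
    // sparse, high-dimensional data such as text vectors.
    public static double cosine(double[] a, double[] b) {
        double dot = 0.0, na = 0.0, nb = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return 1.0 - dot / (Math.sqrt(na) * Math.sqrt(nb));
    }
}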
For the basic idea behind the kmeans algorithm, see: kmeans
The algorithm itself is straightforward; the task now is to implement it on top of the MapReduce framework.
In theory, implementing KMeans with MapReduce is a natural fit: in the Mapper, compute for each sample point which center it is closest to, then emit a key-value pair (cluster id the sample belongs to, sample point); after the shuffle, all sample points belonging to the same centroid arrive at the Reducer in a single list, which makes it easy to compute the new center and emit a new key-value pair (centroid id, centroid). In practice, however, things are not as simple as they look on paper.
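To make that natural formulation concrete, here is a minimal sketch of such a Mapper and Reducer. It reuses the Sample writable defined later in this post; loading the centers in setup() is assumed and omitted here, and this is an illustration rather than the implementation actually used below (which, for the practical reasons discussed next, reads the centers in map() and does the assignment in cleanup()):

import java.io.IOException;
import java.util.List;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;

public static class NaturalMapper extends Mapper<LongWritable, Text, IntWritable, Sample> {
    private List<Sample> centers;   // assumed to be loaded in setup(), omitted here

    @Override
    public void map(LongWritable key, Text value, Context context)
            throws IOException, InterruptedException {
        // Parse one sample per input line.
        String[] str = value.toString().split("\\s+");
        Sample sample = new Sample();
        for (int i = 0; i < Sample.DIMENTION; i++)
            sample.arr[i] = Double.parseDouble(str[i]);
        // Find the nearest centroid and emit (cluster id, sample).
        int nearest = 0;
        double minDist = Double.MAX_VALUE;
        for (int c = 0; c < centers.size(); c++) {
            double d = Sample.getEulerDist(sample, centers.get(c));
            if (d < minDist) { minDist = d; nearest = c; }
        }
        context.write(new IntWritable(nearest), sample);
    }
}

public static class NaturalReducer extends Reducer<IntWritable, Sample, IntWritable, Sample> {
    @Override
    public void reduce(IntWritable key, Iterable<Sample> values, Context context)
            throws IOException, InterruptedException {
        // Average all samples of this cluster to get the new centroid.
        Sample centroid = new Sample();
        int count = 0;
        for (Sample s : values) {
            for (int i = 0; i < Sample.DIMENTION; i++)
                centroid.arr[i] += s.arr[i];
            count++;
        }
        for (int i = 0; i < Sample.DIMENTION; i++)
            centroid.arr[i] /= count;
        context.write(key, centroid);   // (centroid id, centroid)
    }
}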
Two problems must be solved to implement the algorithm:
1. How to store the centroids produced by each iteration.
2. How to store the original data to be clustered.
Hadoop offers three main ways to share variables (data) across tasks; a minimal sketch of method 1 follows the table.

| No. | Method |
| --- | ------ |
| 1 | Use Configuration's set method; only suitable when the shared data is small |
| 2 | Put the shared file on HDFS and read it each time; relatively inefficient |
| 3 | Put the shared file in the DistributedCache; after one initialization in setup it can be reused many times; the drawback is that it is read-only, no modification is supported |
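A minimal sketch of method 1 (the property name kmeans.centers is an arbitrary choice for this example):

import org.apache.hadoop.conf.Configuration;

public class ConfShareExample {
    public static void main(String[] args) {
        // Driver side: put a small value into the job Configuration.
        // "kmeans.centers" is a hypothetical property name, for illustration only.
        Configuration conf = new Configuration();
        conf.set("kmeans.centers", "0.5,1.2,3.4");
        // Task side (e.g., in Mapper.setup) it would be read back with:
        //   String centers = context.getConfiguration().get("kmeans.centers");
        System.out.println(conf.get("kmeans.centers"));
    }
}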
We therefore need two centroid files: prevCenterFile holds the centroids from the previous iteration, and currCenterFile holds the centroids just updated by the reducer. The Mapper reads centroids from prevCenterFile, and the Reducer writes the updated centroids to currCenterFile. The driver (main) reads both prevCenterFile and currCenterFile and checks whether the two sets of centroids are identical (or close enough); if so, it stops iterating, otherwise it overwrites prevCenterFile with currCenterFile (using fs.rename) and starts the next iteration. (PS: this scheme is not particularly efficient either; doing the computation in memory with Spark is much faster.)
Reference code: kmeans reference
package kmeans;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.Writable;

// A fixed-dimension sample point (also used for centroids), serializable by Hadoop.
public class Sample implements Writable {
    private static final Log log = LogFactory.getLog(Sample.class);
    public static final int DIMENTION = 60;
    public double arr[];

    public Sample() {
        arr = new double[DIMENTION];
    }

    // Euclidean distance between two samples.
    public static double getEulerDist(Sample vec1, Sample vec2) {
        if (!(vec1.arr.length == DIMENTION && vec2.arr.length == DIMENTION)) {
            log.error("vector's dimension is not " + DIMENTION);
            System.exit(1);
        }
        double dist = 0.0;
        for (int i = 0; i < DIMENTION; ++i) {
            dist += (vec1.arr[i] - vec2.arr[i]) * (vec1.arr[i] - vec2.arr[i]);
        }
        return Math.sqrt(dist);
    }

    // Reset all components to zero (used when accumulating a new centroid).
    public void clear() {
        for (int i = 0; i < arr.length; i++)
            arr[i] = 0.0;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(String.valueOf(arr[0]));
        for (int i = 1; i < DIMENTION; i++)
            sb.append("\t").append(arr[i]);
        return sb.toString();
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        String str[] = in.readUTF().split("\\s+");
        for (int i = 0; i < DIMENTION; ++i)
            arr[i] = Double.parseDouble(str[i]);
    }

    @Override
    public void write(DataOutput out) throws IOException {
        out.writeUTF(this.toString());
    }
}
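A quick usage sketch for Sample (hypothetical, not part of the original code): with DIMENTION = 60, the distance between the all-ones vector and the origin is sqrt(60) ≈ 7.746.

public class SampleDemo {
    public static void main(String[] args) {
        Sample a = new Sample();
        Sample b = new Sample();
        for (int i = 0; i < Sample.DIMENTION; i++) {
            a.arr[i] = 1.0;   // all-ones vector
            b.arr[i] = 0.0;   // origin
        }
        // Each of the 60 components differs by 1, so the distance is sqrt(60).
        System.out.println(Sample.getEulerDist(a, b));  // ~7.746
    }
}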
package kmeans;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Vector;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
public class KMeans extends Configured implements Tool {
    private static final Log log = LogFactory.getLog(KMeans.class);
    private static final int K = 10;
    private static final int MAXITERATIONS = 300;
    private static final double THRESHOLD = 0.01;

    // Decide whether to stop iterating by comparing the previous and current centroids.
    public static boolean stopIteration(Configuration conf) throws IOException {
        FileSystem fs = FileSystem.get(conf);
        Path prevCenterFile = new Path("/user/orisun/input/centers");
        Path currentCenterFile = new Path("/user/orisun/output/part-r-00000");
        if (!(fs.exists(prevCenterFile) && fs.exists(currentCenterFile))) {
            log.info("both centroid files must exist");
            System.exit(1);
        }
        // Stop only if every centroid moved less than THRESHOLD.
        boolean stop = true;
        String line1, line2;
        FSDataInputStream in1 = fs.open(prevCenterFile);
        FSDataInputStream in2 = fs.open(currentCenterFile);
        InputStreamReader isr1 = new InputStreamReader(in1);
        InputStreamReader isr2 = new InputStreamReader(in2);
        BufferedReader br1 = new BufferedReader(isr1);
        BufferedReader br2 = new BufferedReader(isr2);
        Sample prevCenter, currCenter;
        while ((line1 = br1.readLine()) != null && (line2 = br2.readLine()) != null) {
            prevCenter = new Sample();
            currCenter = new Sample();
            String[] str1 = line1.split("\\s+");
            String[] str2 = line2.split("\\s+");
            // The first field on each line is the centroid index.
            assert (str1[0].equals(str2[0]));
            for (int i = 1; i <= Sample.DIMENTION; i++) {
                prevCenter.arr[i - 1] = Double.parseDouble(str1[i]);
                currCenter.arr[i - 1] = Double.parseDouble(str2[i]);
            }
            if (Sample.getEulerDist(prevCenter, currCenter) > THRESHOLD) {
                stop = false;
                break;
            }
        }
        br1.close();
        br2.close();
        // If another iteration is needed, replace the previous centroids with the current ones.
        if (stop == false) {
            fs.delete(prevCenterFile, true);
            if (fs.rename(currentCenterFile, prevCenterFile) == false) {
                log.error("failed to replace the centroid file");
                System.exit(1);
            }
        }
        return stop;
    }
    public static class ClusterMapper extends Mapper<LongWritable, Text, IntWritable, Sample> {
        Vector<Sample> centers = new Vector<Sample>();

        // Allocate K empty centers.
        @Override
        public void setup(Context context) {
            for (int i = 0; i < K; i++) {
                centers.add(new Sample());
            }
        }

        // The job input is the centroid file, so map() reads the centers.
        @Override
        public void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String[] str = value.toString().split("\\s+");
            if (str.length != Sample.DIMENTION + 1) {
                log.error("wrong dimension while reading centers");
                System.exit(1);
            }
            int index = Integer.parseInt(str[0]);
            for (int i = 1; i < str.length; i++)
                centers.get(index).arr[i - 1] = Double.parseDouble(str[i]);
        }

        // After all centers are read, scan the sample data (delivered via the
        // DistributedCache) and assign each sample to its nearest centroid.
        @Override
        public void cleanup(Context context) throws IOException, InterruptedException {
            Path[] caches = DistributedCache.getLocalCacheFiles(context.getConfiguration());
            if (caches == null || caches.length <= 0) {
                log.error("data file does not exist");
                System.exit(1);
            }
            BufferedReader br = new BufferedReader(new FileReader(caches[0].toString()));
            Sample sample;
            String line;
            while ((line = br.readLine()) != null) {
                sample = new Sample();
                String[] str = line.split("\\s+");
                for (int i = 0; i < Sample.DIMENTION; i++)
                    sample.arr[i] = Double.parseDouble(str[i]);
                int index = -1;
                double minDist = Double.MAX_VALUE;
                for (int i = 0; i < K; i++) {
                    double dist = Sample.getEulerDist(sample, centers.get(i));
                    if (dist < minDist) {
                        minDist = dist;
                        index = i;
                    }
                }
                context.write(new IntWritable(index), sample);
            }
            br.close();
        }
    }
    public static class UpdateCenterReducer extends Reducer<IntWritable, Sample, IntWritable, Sample> {
        int prev = -1;
        Sample center = new Sample();
        int count = 0;

        // Update each centroid (except the last one): the running sum for a
        // cluster is held in instance fields across reduce() calls, and the
        // finished centroid is emitted when the next key arrives.
        @Override
        public void reduce(IntWritable key, Iterable<Sample> values, Context context)
                throws IOException, InterruptedException {
            for (Sample value : values) {
                if (key.get() != prev) {
                    if (prev != -1) {
                        for (int i = 0; i < center.arr.length; i++)
                            center.arr[i] /= count;
                        context.write(new IntWritable(prev), center);
                    }
                    center.clear();
                    prev = key.get();
                    count = 0;
                }
                for (int i = 0; i < Sample.DIMENTION; i++)
                    center.arr[i] += value.arr[i];
                count++;
            }
        }

        // Emit the last centroid.
        @Override
        public void cleanup(Context context) throws IOException, InterruptedException {
            for (int i = 0; i < center.arr.length; i++)
                center.arr[i] /= count;
            context.write(new IntWritable(prev), center);
        }
    }
    @Override
    public int run(String[] args) throws Exception {
        Configuration conf = getConf();
        FileSystem fs = FileSystem.get(conf);
        Job job = new Job(conf);
        job.setJarByClass(KMeans.class);
        // The first number on each line of the centroid file is the centroid index.
        FileInputFormat.setInputPaths(job, "/user/orisun/input/centers");
        Path outDir = new Path("/user/orisun/output");
        fs.delete(outDir, true);
        FileOutputFormat.setOutputPath(job, outDir);
        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapperClass(ClusterMapper.class);
        job.setReducerClass(UpdateCenterReducer.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Sample.class);
        return job.waitForCompletion(true) ? 0 : 1;
    }
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(conf);
        // Samples in the data file carry no index; distribute the file via the cache.
        Path dataFile = new Path("/user/orisun/input/data");
        DistributedCache.addCacheFile(dataFile.toUri(), conf);
        int iteration = 0;
        int success = 1;
        do {
            // run() returns 0 on success, so success stays 1 while jobs keep succeeding.
            success ^= ToolRunner.run(conf, new KMeans(), args);
            log.info("iteration " + iteration + " end");
        } while (success == 1 && iteration++ < MAXITERATIONS
                && (!stopIteration(conf)));
        log.info("Success.Iteration=" + iteration);
        // After convergence, run the mapper once more to output the cluster of every
        // sample point -- written to /user/orisun/output2/part-m-00000.
        // The final centroids are kept in /user/orisun/input/centers.
        Job job = new Job(conf);
        job.setJarByClass(KMeans.class);
        FileInputFormat.setInputPaths(job, "/user/orisun/input/centers");
        Path outDir = new Path("/user/orisun/output2");
        fs.delete(outDir, true);
        FileOutputFormat.setOutputPath(job, outDir);
        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapperClass(ClusterMapper.class);
        job.setNumReduceTasks(0);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Sample.class);
        job.waitForCompletion(true);
    }
}
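As the PS above notes, an in-memory engine such as Spark avoids re-reading the centroid files from HDFS on every iteration. For comparison, here is a minimal sketch using Spark MLlib's built-in KMeans via the Java API; the input path, k = 10, and the iteration cap merely mirror the constants above and are illustrative assumptions, not part of the original post:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class SparkKMeansSketch {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext(new SparkConf().setAppName("kmeans"));
        // Parse whitespace-separated sample vectors; the path is illustrative.
        JavaRDD<Vector> points = sc.textFile("/user/orisun/input/data").map(line -> {
            String[] str = line.split("\\s+");
            double[] arr = new double[str.length];
            for (int i = 0; i < str.length; i++)
                arr[i] = Double.parseDouble(str[i]);
            return Vectors.dense(arr);
        }).cache();  // keep the samples in memory across iterations
        // k and the iteration cap mirror K and MAXITERATIONS above.
        KMeansModel model = KMeans.train(points.rdd(), 10, 300);
        for (Vector center : model.clusterCenters())
            System.out.println(center);
        sc.stop();
    }
}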