基于Hadoop的Kmeans算法实现

    Kmeans算法是很典型的基于距离的聚类算法,采用距离作为相似性的评价指标。即认为两个对象的距离越近,其相似度就越大。该算法认为簇是由距离靠近的对象组成的,因此把得到紧凑且独立的簇作为最终目标。

    算法流程如下:

    1. 从N条数据中随机选取K条数据作为初始聚类中心;

    2. 对剩余的每条数据测量其到每个聚类中心的距离,并将其归到最近的中心的类;

    3. 重新计算已经得到的各个类的聚类中心;

    4. 迭代2~3步,直到新的聚类中心与原聚类中心相等或小于指定的阈值,算法结束。

    理解算法流程之后,需要将上述流程转换为可以在Hadoop集群上运行的MR程序。

    第一步,从N条数据中随机选取K条数据作为初始聚类中心,这个可以利用上一篇中的蓄水池抽样来接近等概率的抽取K条数据。

    第二步,由于第一步随机选取的K条聚类中心的数据相比于整体数据来说所占比例是很小的,因此在MR设计中,将这K条数据用Java API去读取并存放在内存中。在Mapper的setup中需要初始化聚类个数K,分隔符spliter,以及读取初始聚类的数据。然后在map中去计算传进来的每条数据与这K个聚类中心的距离,并得出最短距离以及其所属的类别(用聚类中心的下标表示)。然后将类别以及这条数据传给Reducer端。

    Mapper如下:

public class KMeansMapper extends Mapper<LongWritable, Text, IntWritable, MyWritable> {
	private Logger logger = LoggerFactory.getLogger(KMeansMapper.class);
	private String centerPathStr="";
	private String splitter ="";
	private int k;// 存储聚类中心个数
	private String[] centerVec= null; // 存储聚类中心向量
	
	@Override
	protected void setup(Context context)
			throws IOException, InterruptedException {
		centerPathStr = context.getConfiguration().get(Utils.CENTERPATH);
		splitter = context.getConfiguration().get(Utils.SPLITTER);
		k = context.getConfiguration().getInt(Utils.K, 0);
		centerVec  = new String[k];
		
		// TODO 读取聚类中心到数组 centerVec中
		Path path = new Path(centerPathStr);
		FSDataInputStream is=Utils.getFs().open(path);
		BufferedReader br=new BufferedReader(new InputStreamReader(is));
		String line="";
		int i=0;
		while( (line=br.readLine())!=null){
		 centerVec[i++]=line;
	}
		br.close();
		is.close();

	}
	
	private IntWritable ID = new IntWritable();
	private MyWritable mw = new MyWritable();
	@Override
	protected void map(LongWritable key, Text value, Context context)
			throws IOException, InterruptedException {
		int vecId = getCenterId(value.toString());
		
		ID.set(vecId);
//		logger.info("**********************************"+vecId);
		mw.setData(value.toString());
		context.write(ID, mw);
//		logger.info("ID:{},value:{}",new Object[]{vecId,value});
	}
	
	/**
	 * 计算当前行到聚类中心向量中距离最小的下标;
	 * @param string
	 * @return
	 */
	private int getCenterId(String line) {
		int type=-1;
		double min=Double.MAX_VALUE;
		double distance=0.0;
		for(int i=0;i<centerVec.length;i++){
			distance=Utils.calDistance(line,centerVec[i],splitter);
			if(distance<min){
				min=distance;
				type=i;
			}
		}
		return type;
		}

}
    其中计算两条数据之间的欧式距离如下:

	public static double calDistance(String line, String string,String splitter) {
		double sum=0;
		String[] data=line.split(splitter);
		String[] centerI=string.split(splitter);
		for(int i=0;i<data.length;i++){
			sum+=Math.pow(Double.parseDouble(data[i])-Double.parseDouble(centerI[i]), 2);
		}
		return Math.sqrt(sum);
	}
    自定义值类型MyWritable包含两个变量,一个是当前数据的条数,一个是当前数据的值:
public class MyWritable implements Writable {
	
	private int num = 1;
	private String data;
	public MyWritable() {
		// TODO Auto-generated constructor stub
	}
	public MyWritable(int num, String data){
		this.num = num;
		this.data = data;
	}
	@Override
	public void write(DataOutput out) throws IOException {
		// TODO Auto-generated method stub
		out.writeInt(num);
		out.writeUTF(data);
	}


	@Override
	public void readFields(DataInput in) throws IOException {
		// TODO Auto-generated method stub
		num = in.readInt();
		data = in.readUTF();
	}
	public int getNum() {
		return num;
	}

	public void setNum(int num) {
		this.num = num;
	}

	public String getData() {
		return data;
	}

	public void setData(String data) {
		this.data = data;
	}

}
    在数据量大的情况下,如果需要优化MR,可以添加一个Combiner如下:

public class KmeansCombiner extends Reducer<IntWritable, MyWritable, IntWritable, MyWritable>{
	private String splitter ;
	private Pattern pattern;
	@Override
	protected void setup(Context context)
			throws IOException, InterruptedException {
		splitter = context.getConfiguration().get(SPLITTER);
		pattern = Pattern.compile(",");
	}
	/***
	 * map1 -> (0,"1.1,1.2,1.3"),(0,"1.3,1.2,1.4"),(1,"4.6,5.7,8.8")
	 * combiner1 -> (0,"2.4,2.4,2.7") , (1,"4.6,5.7,8.8")
	 * map2 -> (0,"2.1,2.4,1.2"),(2,"12.1,11.1,13.2"),(2,"14.1,12.3,15.2")
	 * combiner2 ->(0,"2.1,2.4,1.2"),(2,"26.2,23.4,28.4")
	 * 因此 combiner传值给reducer的时候需要传递当前类别的个数
	 * 
	 */
	MyWritable result = new MyWritable();
	@Override
	protected void reduce(IntWritable key, Iterable<MyWritable> values,
			Context context)
			throws IOException, InterruptedException {
		double[] sum=null;
		long  num =0;
		for(MyWritable value:values){
			String[] valStr = pattern.split(value.getData().toString(), -1);
			if(sum==null){// 初始化
				sum=new double[valStr.length];
				addToSum(sum,valStr);// 第一次需要加上
			}else{
			//	对应字段相加
				addToSum(sum,valStr);
			}
			num++;			
		}
		result.setData(format(sum));
		result.setNum((int) num);
		context.write(key, result);
	}
	
	/**
	 * 对应字段相加
	 * @param sum
	 * @param valStr
	 */
		private void addToSum(double[] sum, String[] valStr) {
			//  实现功能
			for(int i=0;i<sum.length;i++){
				sum[i]+=Double.parseDouble(valStr[i]);
			}

		}
		private String format(double[] sum) {
			//完善功能
			String str="";
			for(int i=0;i<sum.length;i++){
				if(i==0){
					str=str.concat(String.valueOf(sum[i]));
				}else{
					str=str.concat(splitter+String.valueOf(sum[i]));
				}
			}
			return str;
		}
}
    在Reducer端,需要做的是,重新计算每个类别的聚类中心。也就是计算每个类别中的所有元素对应位置的和除以当前类别中数据的总条数。

public class KMeansReducer extends Reducer<IntWritable, MyWritable, Text, NullWritable> {

	private String splitter ;
	private Pattern pattern;
	private String[] centerVec = null;
	private int k;
	
	private Logger log = LoggerFactory.getLogger(KMeansReducer.class);
	@Override
	protected void setup(Context context)
			throws IOException, InterruptedException {
		splitter = context.getConfiguration().get(SPLITTER);
		pattern = Pattern.compile(",");
		k= context.getConfiguration().getInt(K, 0);
		centerVec = new String[k];
	}
	
	
	@Override
	protected void reduce(IntWritable key, Iterable<MyWritable> values,
			Context arg2) throws IOException, InterruptedException {
		double[] sum=null;
		long  num =0;
		for(MyWritable value:values){
			int number = value.getNum();
			String[] valStr = pattern.split(value.getData().toString(), -1);
			if(sum==null){// 初始化
				sum=new double[valStr.length];
				addToSum(sum,valStr);// 第一次需要加上
			}else{
			//	对应字段相加
				addToSum(sum,valStr);
			}
			num += number;			
		}
		averageSum(sum,num);
		centerVec[key.get()]= format(sum);
	}
	private Text vec = new Text();
	/**
	 * 直接输出数组centerVec
	 */
	@Override
	protected void cleanup(Context context)
			throws IOException, InterruptedException {
		for(int i=0;i<centerVec.length;i++ ){
			
			vec.set(centerVec[i]);
			context.write(vec, NullWritable.get());
		}
	}

/**
 * 求平均值
 * @param sum
 * @param num
 */
	private void averageSum(double[] sum, long num) {
		//求平均值
		
		for(int i=0;i<sum.length;i++){
			sum[i]=sum[i]/num;
		}
	
	}

/**
 * 对应字段相加
 * @param sum
 * @param valStr
 */
	private void addToSum(double[] sum, String[] valStr) {
		//  实现功能
		for(int i=0;i<sum.length;i++){
			sum[i]+=Double.parseDouble(valStr[i]);
		}

	}

/**
 * 格式化数组
 * 数组元素之间的分隔符采用splitter即可
 * @param sum
 * @return
 */
	private String format(double[] sum) {
		//完善功能
		String str="";
		for(int i=0;i<sum.length;i++){
			if(i==0){
				str=str.concat(String.valueOf(sum[i]));
			}else{
				str=str.concat(splitter+String.valueOf(sum[i]));
			}
		}
		return str;
	}
}

    上面的MR,只是计算了一次新的聚类中心,并且还没有计算新的聚类中心与上一次聚类中心的距离是否满足我们的阈值。因此在需要再写一个job类,来汇总所有的代码。

public class Alljobs {

	public static void main(String[] args) throws Exception {
		// TODO Auto-generated method stub
		String[] KmeansArgs = new String[]{
					"hdfs://master:8020/user/Administrator/Kmeans/k.csv", //原始数据
					"hdfs://master:8020/user/Administrator/Kmeans/k_all", // 输出各个数据的类别
					"3", // 聚几类
					",", // 原始数据分隔符
					"20",// 迭代次数
					"0.5", // 误差阈值
					"0" //start: 0-> 初始化聚类中心,1 -> 计算新的聚类中心  2 -> 是否分类
		};
		String input = KmeansArgs[0];
		String output = KmeansArgs[1];
		int k = Integer.valueOf(KmeansArgs[2]);
		String splitter = KmeansArgs[3];
		int iteration = Integer.valueOf(KmeansArgs[4]);
		double delta = Double.parseDouble(KmeansArgs[5]);
		int start = Integer.valueOf(KmeansArgs[6]);
		int number = 0;
		// 1. 初始化聚类中心向量(SampleJob)
		int ret = -1;
		String fileStr = "iter";
		switch (start){
		case 0 :first(input, output, k, ret);
		case 1: number = updateKmeans(input, output, k, 
				splitter, delta, ret, iteration);
		case 2: if(start == 2){
			number = readLastFile(output,fileStr)-1;
			}
			classify(number, input, output, k, 
				splitter, iteration);
		default: break;
		}
	}
	public static void first(String input,String output,
			int k,
			int ret) throws Exception{
		String[] job1Args = new String[]{
				input,
				output+"/iter0",
				String.valueOf(k)
		};
		ret = ToolRunner.run(Utils.getConf(),new Driver.MyDriver(), job1Args);
		if(ret != 0){
			System.err.println("sample job failed!");
			System.exit(-1);
		}
	}
	// 2. 循环Kmeans,更新聚类中心
			public static int updateKmeans(String input,String output,
					int k,String splitter,double delta,
					int ret,int iteration) throws Exception{ 
				int num = 0;
				for(int i=0;i<iteration;i++){
					String[] jobArgs = new String[]{
							input, // input
							output+"/iter"+(i+1),  //当前聚类中心
							splitter, // splitter
							String.valueOf(k),
							output+"/iter"+i+"/part-r-00000" // 上一次聚类中心
					};
					ret = ToolRunner.run(Utils.getConf(), new KMeansDriver(), jobArgs);
					if(ret != 0){
						System.err.println("kmeans job failed!"+":"+i);
						System.exit(-1);
					}
					if(!Utils.shouldRunNextIteration(output+"/iter"+i+"/part-r-00000",output+"/iter"+(i+1)+"/part-r-00000",
							delta,splitter)){
						num = i+1;
						break;
					}
				}
				return num;
			}
			// 3. 分类
	public static void classify(int num,String input,String output,
			int k,String splitter,
			int iteration) throws IOException, ClassNotFoundException, InterruptedException{
		if (num == 0) {
			num = iteration;
		}
		Configuration conf = Utils.getConf();
		conf.set(SPLITTER, splitter );
		conf.set(CENTERPATH, output+"/iter"+num+"/part-r-00000");
		conf.setInt(K, k);
		Job job = Job.getInstance(conf,"classify");
		job.setMapperClass(Classify.KMeansMapper.class);
		job.setPartitionerClass(Classify.KmeansPartional.class);
		job.setReducerClass(Classify.ClassifyReducer.class);
		job.setMapOutputKeyClass(IntWritable.class);
		job.setMapOutputValueClass(Text.class);
		job.setOutputKeyClass(IntWritable.class);
		job.setOutputValueClass(Text.class);
		job.setNumReduceTasks(k);
		FileInputFormat.addInputPath(job, new Path(input));
		Path out =new Path(output+"/clustered");
	    FileOutputFormat.setOutputPath(job,out);
	    if(Utils.getFs().exists(out)){
	    	Utils.getFs().delete(out, true);
	    }
	    System.exit(job.waitForCompletion(true) ? 0 : 1);
	}
	public static int readLastFile(String output,String fileStr) throws IOException{
//		return Utils.getFs().listStatus(new Path(output), 
//				new PathFilter() {	
//			@Override
//			public boolean accept(Path file) {
//				return file.getName().indexOf(fileStr) == 0;
//			} // fileStr为传进来的final值,就是文件名以某个字符串开头
//		}).length;
		Path path = new Path(output);
		FileStatus[] fs = Utils.getFs().listStatus(path);
		int num = 0;
		for(int i=0;i<fs.length;i++){
			if(fs[i].getPath().getName().startsWith(fileStr)){
				num++;
			}
		}
		return num;	
	}
}

    在第2步中,判断新的聚类中心与上次聚类中心之间的距离是否满足我们初始化时设置的阈值,如果满足则停止,然后对所有数据分类保存。如果不满足则迭代计算直到满足阈值或者迭代次数达到设定的次数。

    在第3步是否分类中,为了将数据按照不同的类别分开保存数据,因此又写了一个MR,并在其中添加了partition来进行分类保存。在新的MR中,基本与上述MR类似,只是少了计算新的聚类中心而已,直接将数据输出即可。

    partition如下:

public class KmeansPartional extends Partitioner<IntWritable, Text>{

	@Override
	public int getPartition(IntWritable key, Text value, int arg2) {
		if(key.get() == 0){
			return 0;
		}else if(key.get() == 1){
			return 1;
		}else{
		return 2;
		}
	}

}

    上述代码即为Kmeans在Hadoop中的实现,代码的实用性还不错。比如你可以通过设定参数,来决定代码只运行计算新的聚类中心,或者只运行最后的分类保存,前提是你需要先有初始聚类中心的数据。

    运行结果如下:




    

  • 4
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值