/*
 * Algorithm outline:
 * 1. main() reads the initial centroid file.
 * 2. The centroid string is stored in the Configuration.
 * 3. The Mapper's setup() reads the centroid string from the Configuration and
 *    parses it into a 2-D array representing the centroids.
 * 4. The Mapper's map() reads sample lines, compares each sample against every
 *    centroid, finds the nearest one, and emits <centroid, sample>.
 * 5. The Reducer recomputes each centroid; if the recomputed centroid equals the
 *    incoming one, a custom counter is incremented.
 * 6. main() reads the counter value; if it is not equal to the number of
 *    centroids, another iteration runs, otherwise the program exits.
 */
package kmeans;
import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.LineReader;
/**
 * Iterative KMeans clustering on Hadoop MapReduce.
 *
 * <p>Each round: the mapper assigns every sample to its nearest centroid, the
 * reducer averages the samples of each cluster into a new centroid, and the
 * driver re-runs until the centroids stop moving (tracked via a custom counter)
 * or a fixed number of iterations is reached.
 */
public class kmeans {
    /** Configuration key under which the serialized centroids are passed to the tasks. */
    private static String FLAG = "kcluster";

    /**
     * Mapper: assigns each comma-separated sample line to its nearest centroid.
     * Emits &lt;centroid string, original sample line&gt;.
     */
    public static class TokenizerMapper extends Mapper<Object, Text, Text, Text> {
        // Parsed centroids: one row per centroid, one column per dimension.
        double[][] centers = new double[Center.k][];
        // Raw centroid strings (tab-separated in the configuration value);
        // kept so the original string can be re-emitted as the map output key.
        String[] centerstrArray = null;

        @Override
        public void setup(Context context) {
            // Rebuild the centroid matrix from the tab/comma separated string
            // that main() placed into the job configuration.
            String kmeansS = context.getConfiguration().get(FLAG);
            centerstrArray = kmeansS.split("\t");
            for (int i = 0; i < centerstrArray.length; i++) {
                String[] segs = centerstrArray[i].split(",");
                centers[i] = new double[segs.length];
                for (int j = 0; j < segs.length; j++) {
                    centers[i][j] = Double.parseDouble(segs[j]);
                }
            }
        }

        @Override
        protected void map(Object key, Text value, Mapper<Object, Text, Text, Text>.Context context)
                throws IOException, InterruptedException {
            String line = value.toString();
            String[] segs = line.split(",");
            double[] sample = new double[segs.length];
            for (int i = 0; i < segs.length; i++) {
                // BUGFIX: parse directly as double; the original Float.parseFloat
                // silently truncated the sample values to float precision.
                sample[i] = Double.parseDouble(segs[i]);
            }
            // Find the nearest centroid (smallest Euclidean distance).
            double min = Double.MAX_VALUE;
            int index = 0;
            for (int i = 0; i < centers.length; i++) {
                double dis = distance(centers[i], sample);
                if (dis < min) {
                    min = dis;
                    index = i;
                }
            }
            // Emit the chosen centroid together with the original sample line.
            context.write(new Text(centerstrArray[index]), new Text(line));
        }
    }

    /**
     * Reducer: recomputes a centroid as the per-dimension mean of the samples
     * assigned to it, and increments a counter when the centroid did not move.
     */
    public static class IntSumReducer extends Reducer<Text, Text, NullWritable, Text> {
        @Override
        protected void reduce(Text key, Iterable<Text> values, Reducer<Text, Text, NullWritable, Text>.Context context)
                throws IOException, InterruptedException {
            // BUGFIX: size the accumulator by the sample dimensionality, not by
            // the cluster count. The original "new double[Center.k]" threw
            // ArrayIndexOutOfBoundsException for samples with more than
            // Center.k dimensions.
            double[] sum = null;
            int size = 0;
            // Sum every dimension over all samples assigned to this centroid.
            for (Text text : values) {
                String[] segs = text.toString().split(",");
                if (sum == null) {
                    sum = new double[segs.length];
                }
                for (int i = 0; i < segs.length; i++) {
                    sum[i] += Double.parseDouble(segs[i]);
                }
                size++;
            }
            if (sum == null) {
                // Defensive: no samples for this key (should not happen in practice).
                return;
            }
            // Average each dimension -> new centroid. Join with commas, without
            // the trailing comma the original version emitted.
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < sum.length; i++) {
                sum[i] /= size;
                if (i > 0) {
                    sb.append(",");
                }
                sb.append(sum[i]);
            }
            // Compare the new centroid with the incoming one, dimension by dimension.
            boolean flag = true;
            String[] centerStrArray = key.toString().split(",");
            if (centerStrArray.length != sum.length) {
                // Dimension mismatch means the centroid definitely changed shape.
                flag = false;
            } else {
                for (int i = 0; i < centerStrArray.length; i++) {
                    if (Math.abs(Double.parseDouble(centerStrArray[i]) - sum[i]) > 0.00000000001) {
                        flag = false;
                        break;
                    }
                }
            }
            // Unchanged centroid -> bump the convergence counter. No cast to the
            // legacy mapred Counter type is needed; the new-API counter is used directly.
            if (flag) {
                context.getCounter("myCounter", "kmenasCounter").increment(1);
            }
            // Output only the new centroid (key is the NullWritable singleton,
            // rather than a raw null, per the declared output key class).
            context.write(NullWritable.get(), new Text(sb.toString()));
        }
    }

    /**
     * Driver: runs up to 5 MapReduce rounds, stopping early when the
     * convergence counter equals Center.k (i.e. no centroid moved).
     */
    public static void main(String[] args) throws Exception {
        Path kMeansPath = new Path("C:/kMeans/kMeans.txt"); // initial centroid file
        Path samplePath = new Path("C:/kMeans/sample.txt"); // sample file
        // Path kMeansPath = new Path("/kMeans/kMean");     // initial centroid file (HDFS)
        // Path samplePath = new Path("/kMeans/sample");    // sample file (HDFS)
        // Load the initial cluster centers.
        Center center = new Center();
        String centerString = center.loadInitCenter(kMeansPath);
        int index = 0; // iteration count
        while (index < 5) {
            Configuration conf = new Configuration();
            conf.set(FLAG, centerString); // hand the centroid string to the tasks
            // This round's output path, which is also the next round's centroid source.
            kMeansPath = new Path("C:/kMeans" + index);
            // Delete the output path if it already exists, otherwise the job fails.
            FileSystem hdfs = FileSystem.get(conf);
            if (hdfs.exists(kMeansPath)) {
                // BUGFIX: use the recursive, non-deprecated delete overload.
                hdfs.delete(kMeansPath, true);
            }
            // Job.getInstance replaces the deprecated Job(Configuration, String) constructor.
            Job job = Job.getInstance(conf, "kmeans" + index);
            job.setJarByClass(kmeans.class);
            job.setMapperClass(TokenizerMapper.class);
            job.setReducerClass(IntSumReducer.class);
            job.setOutputKeyClass(NullWritable.class);
            job.setOutputValueClass(Text.class);
            job.setMapOutputKeyClass(Text.class);
            job.setMapOutputValueClass(Text.class);
            FileInputFormat.addInputPath(job, samplePath);
            FileOutputFormat.setOutputPath(job, kMeansPath);
            job.waitForCompletion(true);
            // If every centroid stayed put (counter == number of centroids),
            // the clustering has converged and the program stops iterating.
            long counter = job.getCounters().findCounter("myCounter", "kmenasCounter").getValue();
            if (counter == Center.k) System.exit(0);
            // Reload the centroids produced by this round for the next iteration.
            center = new Center();
            centerString = center.loadCenter(kMeansPath);
            index++;
        }
        System.exit(0);
    }

    /**
     * Euclidean distance between two vectors.
     *
     * @param a first vector (may be null)
     * @param b second vector (may be null)
     * @return the distance, or Double.MAX_VALUE when either vector is null or
     *         the lengths differ (so mismatched vectors are never "nearest")
     */
    public static double distance(double[] a, double[] b) {
        if (a == null || b == null || a.length != b.length) return Double.MAX_VALUE;
        double dis = 0;
        for (int i = 0; i < a.length; i++) {
            dis += Math.pow(a[i] - b[i], 2);
        }
        return Math.sqrt(dis);
    }
}
/**
 * Helper that loads cluster centroids from HDFS into a single string.
 * Centroids are tab-separated; dimensions within a centroid are comma-separated,
 * matching the format TokenizerMapper.setup() parses.
 */
class Center {
    /** Number of clusters (and of initial centroids). */
    protected static int k = 2;

    /**
     * Loads the initial centroids from a single file, one centroid per line.
     *
     * @param path path of the initial centroid file
     * @return tab-separated centroid string (trimmed)
     * @throws Exception on any HDFS/IO failure
     */
    public String loadInitCenter(Path path) throws Exception {
        // StringBuilder: no synchronization needed for this local accumulator.
        StringBuilder sb = new StringBuilder();
        Configuration conf = new Configuration();
        FileSystem hdfs = FileSystem.get(conf);
        FSDataInputStream dis = hdfs.open(path);
        LineReader in = new LineReader(dis, conf);
        try {
            Text line = new Text();
            while (in.readLine(line) > 0) {
                sb.append(line.toString().trim()); // strip surrounding whitespace
                sb.append("\t");
            }
        } finally {
            // BUGFIX: the reader (and its underlying stream) was never closed,
            // leaking a file handle per call. LineReader.close() closes the stream.
            in.close();
        }
        return sb.toString().trim();
    }

    /**
     * Loads the centroids produced by the previous iteration from the
     * "part" output files under the given job output directory.
     *
     * @param path output directory of the previous MapReduce round
     * @return tab-separated centroid string (trimmed)
     * @throws Exception on any HDFS/IO failure
     */
    public String loadCenter(Path path) throws Exception {
        StringBuilder sb = new StringBuilder();
        Configuration conf = new Configuration();
        FileSystem hdfs = FileSystem.get(conf);
        FileStatus[] files = hdfs.listStatus(path);
        for (int i = 0; i < files.length; i++) {
            Path filePath = files[i].getPath();
            // Only reducer output files ("part-*") hold centroids; skip _SUCCESS etc.
            if (!filePath.getName().contains("part")) continue;
            FSDataInputStream dis = hdfs.open(filePath);
            LineReader in = new LineReader(dis, conf);
            try {
                Text line = new Text();
                while (in.readLine(line) > 0) {
                    sb.append(line.toString().trim());
                    sb.append("\t");
                }
            } finally {
                // BUGFIX: close each reader/stream; the original leaked one per part file.
                in.close();
            }
        }
        return sb.toString().trim();
    }
}