Reading the Code: KMeansDriver

KMeansDriver is the entry point class for k-means:

package org.apache.mahout.clustering.kmeans;

public class KMeansDriver extends AbstractJob


The run() method calls buildClusters() and then, if requested, clusterData():

Path clustersOut = buildClusters(conf, input, clustersIn, output, measure, maxIterations, delta, runSequential);
if (runClustering) {
  log.info("Clustering data");
  clusterData(conf,
              input,
              clustersOut,
              new Path(output, AbstractCluster.CLUSTERED_POINTS_DIR),
              measure,
              delta,
              runSequential);
}
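
For orientation, here is a minimal sketch of driving the whole pipeline programmatically. The exact run() overload varies between Mahout versions, so the parameter order below (mirroring the buildClusters call above) is an assumption; the paths and the EuclideanDistanceMeasure choice are only placeholders.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;

// Hypothetical invocation; adjust to the KMeansDriver.run(...) overload of your Mahout version.
Configuration conf = new Configuration();
Path input = new Path("testdata/points");        // input vectors as a SequenceFile
Path clustersIn = new Path("testdata/clusters");  // initial cluster centers
Path output = new Path("output");

KMeansDriver.run(conf, input, clustersIn, output,
                 new EuclideanDistanceMeasure(),
                 0.001,   // convergence delta
                 10,      // max iterations
                 true,    // runClustering: also assign points after training
                 false);  // runSequential: false = use the MapReduce implementation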


buildClusters() offers two implementations, a sequential one and a MapReduce one:

if (runSequential) {
  return buildClustersSeq(conf, input, clustersIn, output, measure, maxIterations, delta);
} else {
  return buildClustersMR(conf, input, clustersIn, output, measure, maxIterations, delta);
}


buildClustersMR() implements the iterative center-update loop: each iteration writes its clusters to a new clusters-N directory, and the loop stops once runIteration() reports convergence or maxIterations is reached:

boolean converged = false;
int iteration = 1;
while (!converged && iteration <= maxIterations) {
  log.info("K-Means Iteration {}", iteration);
  // point the output to a new directory per iteration
  Path clustersOut = new Path(output, AbstractCluster.CLUSTERS_DIR + iteration);
  converged = runIteration(conf, input, clustersIn, clustersOut, measure.getClass().getName(), delta);
  // now point the input to the old output directory
  clustersIn = clustersOut;
  iteration++;
}
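
How runIteration() turns the per-cluster convergence flags into a single boolean is not shown here; conceptually it is an AND over all clusters written in that iteration. A sketch of that aggregation (an assumption about the driver's behavior, assuming Cluster exposes isConverged(); not copied from Mahout):

// Returns true only if every cluster produced by this iteration is flagged converged.
static boolean allConverged(Iterable<Cluster> clustersWrittenThisIteration) {
  for (Cluster cluster : clustersWrittenThisIteration) {
    if (!cluster.isConverged()) {
      return false;
    }
  }
  return true;
}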


runIteration() sets up the core MapReduce job:

job.setMapOutputKeyClass(Text.class);
job.setMapOutputValueClass(ClusterObservations.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(Cluster.class);


Both input and output are sequence files, and the mapper, combiner, and reducer are wired up here:

job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setMapperClass(KMeansMapper.class);
job.setCombinerClass(KMeansCombiner.class);
job.setReducerClass(KMeansReducer.class);



The KMeansMapper class:

package org.apache.mahout.clustering.kmeans;

public class KMeansMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, ClusterObservations>

// the core clustering logic
private KMeansClusterer clusterer;
// holds the cluster centers
private final Collection<Cluster> clusters = new ArrayList<Cluster>();

setup() loads the distance measure class, initializes the KMeansClusterer, and reads in the cluster centers:

ClassLoader ccl = Thread.currentThread().getContextClassLoader();
DistanceMeasure measure = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY))
    .asSubclass(DistanceMeasure.class).newInstance();
measure.configure(conf);

this.clusterer = new KMeansClusterer(measure);

String clusterPath = conf.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
if (clusterPath != null && clusterPath.length() > 0) {
  KMeansUtil.configureWithClusterInfo(conf, new Path(clusterPath), clusters);
  if (clusters.isEmpty()) {
    throw new IllegalStateException("No clusters found. Check your -c path.");
  }
}
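
KMeansUtil.configureWithClusterInfo() essentially deserializes Cluster objects from the SequenceFiles under the given path. A rough sketch of the idea (this is not the actual utility code, and it ignores non-data entries such as _logs directories):

// uses org.apache.hadoop.fs.{FileSystem, FileStatus}, org.apache.hadoop.io.{SequenceFile, Writable},
// org.apache.hadoop.util.ReflectionUtils
FileSystem fs = FileSystem.get(clusterPath.toUri(), conf);
for (FileStatus status : fs.listStatus(clusterPath)) {
  SequenceFile.Reader reader = new SequenceFile.Reader(fs, status.getPath(), conf);
  try {
    Writable key = (Writable) ReflectionUtils.newInstance(reader.getKeyClass(), conf);
    Cluster value = new Cluster();
    while (reader.next(key, value)) {   // each value is one cluster center
      clusters.add(value);
      value = new Cluster();
    }
  } finally {
    reader.close();
  }
}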


map() simply delegates to the clusterer:

this.clusterer.emitPointToNearestCluster(point.get(), this.clusters, context);


KMeansClusterer is the core class that implements the algorithm.
In emitPointToNearestCluster() it iterates over the cluster centers and picks the one closest to the point by distance.
It emits key: the identifier of the nearest cluster center, value: a ClusterObservations wrapping the point.
ClusterObservations carries s0 (the point count), s1 (the running sum of the vectors), and s2 (the running sum of the squared vectors),
which makes the later computations cheap.

Cluster nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
for (Cluster cluster : clusters) {
  Vector clusterCenter = cluster.getCenter();
  double distance = this.measure.distance(clusterCenter.getLengthSquared(), clusterCenter, point);
  if (distance < nearestDistance || nearestCluster == null) {
    nearestCluster = cluster;
    nearestDistance = distance;
  }
}
context.write(new Text(nearestCluster.getIdentifier()), new ClusterObservations(1, point, point.times(point)));
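
To make the role of s0/s1/s2 concrete, here is a small illustrative fragment (not Mahout code) that accumulates the same three sums for a couple of points and recovers the mean that later becomes the new center:

// uses org.apache.mahout.math.{Vector, DenseVector} and java.util.Arrays
int s0 = 0;                         // point count
Vector s1 = new DenseVector(2);     // running sum of the points
Vector s2 = new DenseVector(2);     // running sum of the squared points
for (Vector x : Arrays.asList(new DenseVector(new double[] {1.0, 2.0}),
                              new DenseVector(new double[] {3.0, 4.0}))) {
  s0++;
  s1 = s1.plus(x);
  s2 = s2.plus(x.times(x));
}
Vector newCenter = s1.divide(s0);   // the mean, here {2.0, 3.0}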



The KMeansCombiner class pre-aggregates the map output:
public class KMeansCombiner extends Reducer<Text, ClusterObservations, Text, ClusterObservations>
It merges the counts and vector sums of all observations that share the same cluster center, which is safe because the s0/s1/s2 sums are plain additions and therefore associative:

@Override
protected void reduce(Text key, Iterable<ClusterObservations> values, Context context)
    throws IOException, InterruptedException {
  Cluster cluster = new Cluster();
  for (ClusterObservations value : values) {
    cluster.observe(value);
  }
  context.write(key, cluster.getObservations());
}


The KMeansReducer class:
public class KMeansReducer extends Reducer<Text, ClusterObservations, Text, Cluster>
It aggregates all observations for a given cluster center, checks convergence, and recomputes the cluster center.
The new center is the vector mean, i.e. the running sum of all vectors divided by their count.
It emits key: the cluster center's identifier, value: the new cluster center.

@Override
protected void reduce(Text key, Iterable<ClusterObservations> values, Context context)
    throws IOException, InterruptedException {
  Cluster cluster = clusterMap.get(key.toString());
  for (ClusterObservations delta : values) {
    cluster.observe(delta);
  }
  // force convergence calculation
  boolean converged = clusterer.computeConvergence(cluster, convergenceDelta);
  if (converged) {
    context.getCounter("Clustering", "Converged Clusters").increment(1);
  }
  cluster.computeParameters();
  context.write(new Text(cluster.getIdentifier()), cluster);
}
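
computeConvergence() compares the old center with the centroid implied by the freshly accumulated observations; a cluster counts as converged when that shift stays within convergenceDelta. A hedged sketch of the test (an assumption about the internals, not the literal Mahout code):

// Converged if the center moved by no more than the user-supplied delta.
static boolean hasConverged(DistanceMeasure measure,
                            Vector oldCenter,
                            Vector newCentroid,
                            double convergenceDelta) {
  return measure.distance(oldCenter, newCentroid) <= convergenceDelta;
}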



clusterData() likewise offers two implementations, a sequential (single-machine) one and a distributed MapReduce one:

if (runSequential) {
  clusterDataSeq(conf, input, clustersIn, output, measure);
} else {
  clusterDataMR(conf, input, clustersIn, output, measure, convergenceDelta);
}



clusterDataMR() again uses sequence files for both input and output; the output key is an IntWritable (the cluster id) and the value is a weighted vector:

job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setOutputKeyClass(IntWritable.class);
job.setOutputValueClass(WeightedVectorWritable.class);


The job is map-only, with no reduce phase:

job.setMapperClass(KMeansClusterMapper.class);
job.setNumReduceTasks(0);



The KMeansClusterMapper class assigns each point to its final cluster and writes it out with the cluster label:

public class KMeansClusterMapper extends Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable>

private final Collection<Cluster> clusters = new ArrayList<Cluster>();
private KMeansClusterer clusterer;

@Override
protected void map(WritableComparable<?> key, VectorWritable point, Context context)
    throws IOException, InterruptedException {
  clusterer.outputPointWithClusterInfo(point.get(), clusters, context);
}


outputPointWithClusterInfo() iterates over all centers, finds the nearest one, and emits
key: the cluster id, value: the point wrapped in a WeightedVectorWritable:


AbstractCluster nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
for (AbstractCluster cluster : clusters) {
  Vector clusterCenter = cluster.getCenter();
  double distance = measure.distance(clusterCenter.getLengthSquared(), clusterCenter, vector);
  if (distance < nearestDistance || nearestCluster == null) {
    nearestCluster = cluster;
    nearestDistance = distance;
  }
}
context.write(new IntWritable(nearestCluster.getId()), new WeightedVectorWritable(1, vector));
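
Downstream code can read the clusteredPoints output back with a plain SequenceFile reader; the key is the cluster id and the value is the labeled point. A minimal sketch (the part-m-00000 file name is the usual Hadoop convention, used here only as an illustration):

// uses org.apache.hadoop.conf.Configuration, org.apache.hadoop.fs.{FileSystem, Path},
// org.apache.hadoop.io.{IntWritable, SequenceFile}, org.apache.mahout.clustering.WeightedVectorWritable
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(conf);
Path clusteredPoints = new Path("output/clusteredPoints/part-m-00000"); // hypothetical path
SequenceFile.Reader reader = new SequenceFile.Reader(fs, clusteredPoints, conf);
try {
  IntWritable clusterId = new IntWritable();
  WeightedVectorWritable point = new WeightedVectorWritable();
  while (reader.next(clusterId, point)) {
    System.out.println(clusterId.get() + "\t" + point.getVector());
  }
} finally {
  reader.close();
}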