1、在Mahout的源码中,每个算法对应的Driver类是该算法实现的主要类。
2、KMeansDriver源码的解读:
/**
 * Command-line entry method of KMeansDriver.
 *
 * @param args the raw command-line arguments, handed to ToolRunner unchanged
 * @throws Exception whatever ToolRunner.run propagates
 */
public static void main(String[] args) throws Exception {
  Configuration configuration = new Configuration();
  KMeansDriver driver = new KMeansDriver();
  // ToolRunner ends up invoking the driver's own run(String[]) method.
  // The value returned by ToolRunner.run is not used.
  ToolRunner.run(configuration, driver, args);
}
/**
 * Instance entry point: registers the command-line options, parses them, and delegates to the
 * static run(...) overload.
 *
 * NOTE(review): this excerpt elides the middle of the method (the "......" line), so the remaining
 * option registrations and the code that assigns input, clusters, output, measure, etc. are not
 * shown here.
 *
 * @param args the command-line arguments
 * @return 0 when the static run(...) call returns without throwing
 * @throws Exception propagated from option handling or from the static run(...) call
 */
public int run(String[] args) throws Exception {
// Register the options this driver accepts: input, output, distance measure, initial clusters, ...
addInputOption();
addOutputOption();
addOption(DefaultOptionCreator.distanceMeasureOption().create());
addOption(DefaultOptionCreator
.clustersInOption()
.withDescription(
"The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy. "
+ "If k is also specified, then a random set of vectors will be selected"
+ " and written out to this path first").create());
addOption(DefaultOptionCreator
.numClustersOption()
......// (elided in this excerpt) this part parses the options entered on the command line
run(getConf(), input, clusters, output, measure, convergenceDelta, maxIterations, runClustering,
clusterClassificationThreshold, runSequential);// parsing is finished; start the actual run
return 0;
}
/**
 * Runs the k-means job: builds the clusters via buildClusters and then, if requested, classifies
 * the input against the resulting clusters via clusterData.
 *
 * @param conf the configuration forwarded to buildClusters and clusterData
 * @param input the input path forwarded to buildClusters and clusterData
 * @param clustersIn the path of the initial clusters, forwarded to buildClusters
 * @param output the output path forwarded to buildClusters and clusterData
 * @param measure the distance measure; its class name is logged and it is forwarded unchanged
 * @param convergenceDelta the convergence threshold; handed to buildClusters as a String
 * @param maxIterations the iteration limit forwarded to buildClusters
 * @param runClustering when true, clusterData is invoked after the clusters have been built
 * @param clusterClassificationThreshold the threshold forwarded to clusterData
 * @param runSequential forwarded to buildClusters and clusterData to select sequential execution
 * @throws IOException propagated from buildClusters or clusterData
 * @throws InterruptedException propagated from buildClusters or clusterData
 * @throws ClassNotFoundException propagated from buildClusters or clusterData
 */
public static void run(Configuration conf, Path input, Path clustersIn, Path output, DistanceMeasure measure,
    double convergenceDelta, int maxIterations, boolean runClustering, double clusterClassificationThreshold,
    boolean runSequential) throws IOException, InterruptedException, ClassNotFoundException {
  // buildClusters takes the convergence threshold as a String.
  String delta = Double.toString(convergenceDelta);
  if (log.isInfoEnabled()) {
    log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] {input, clustersIn, output,
        measure.getClass().getName()});
    // The message previously had four placeholders for three arguments, so the class name was
    // printed under "num Reduce Tasks" and "Input Vectors" printed a literal "{}". The orphaned
    // placeholder is removed so each label lines up with its value.
    log.info("convergence: {} max Iterations: {} Input Vectors: {}", new Object[] {
        convergenceDelta, maxIterations, VectorWritable.class.getName()});
  }
  // iterate until the clusters converge; buildClusters writes the initial clusters and runs the
  // cluster iteration, returning the path that holds the result
  Path clustersOut = buildClusters(conf, input, clustersIn, output, measure, maxIterations, delta, runSequential);
  if (runClustering) {
    log.info("Clustering data");
    // classify the input against the clusters that were just built
    clusterData(conf, input, clustersOut, output, measure, clusterClassificationThreshold, runSequential);
  }
}
/**
 * Loads the initial clusters, writes them out as the prior classifier, and runs the cluster
 * iteration either sequentially or as MapReduce.
 *
 * @param conf the configuration used to read the cluster info and to run the iteration
 * @param input the input path handed to the cluster iterator
 * @param clustersIn the path the initial clusters are read from
 * @param output the output path; the initial-clusters directory is created beneath it
 * @param measure not referenced in this method body
 * @param maxIterations the iteration limit handed to the cluster iterator
 * @param delta the convergence threshold as a String; parsed with Double.parseDouble
 * @param runSequential true to call iterateSeq, false to call iterateMR
 * @return the same output path that was passed in
 * @throws IllegalStateException if no clusters were loaded from clustersIn
 * @throws IOException propagated from the calls made here
 * @throws InterruptedException propagated from the calls made here
 * @throws ClassNotFoundException propagated from the calls made here
 */
public static Path buildClusters(Configuration conf, Path input, Path clustersIn, Path output,
    DistanceMeasure measure, int maxIterations, String delta, boolean runSequential) throws IOException,
    InterruptedException, ClassNotFoundException {
  double threshold = Double.parseDouble(delta);

  // Read the cluster information from the clusters-in path into the list.
  List<Cluster> initialClusters = new ArrayList<Cluster>();
  KMeansUtil.configureWithClusterInfo(conf, clustersIn, initialClusters);
  if (initialClusters.isEmpty()) {
    // Nothing was loaded (for example when the -c argument is missing): fail here by throwing.
    throw new IllegalStateException("No input clusters found in " + clustersIn + ". Check your -c argument.");
  }

  // Persist the initial clusters, wrapped in a classifier that carries the k-means policy
  // built from the convergence threshold, under the initial-clusters directory of the output.
  Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
  ClusteringPolicy kMeansPolicy = new KMeansClusteringPolicy(threshold);
  ClusterClassifier classifier = new ClusterClassifier(initialClusters, kMeansPolicy);
  classifier.writeToSeqFiles(priorPath);

  ClusterIterator iterator = new ClusterIterator();
  if (!runSequential) {
    iterator.iterateMR(conf, input, priorPath, output, maxIterations);
  } else {
    iterator.iterateSeq(conf, input, priorPath, output, maxIterations);
  }
  return output;
}
/**
 * Writes a k-means clustering policy next to the given clusters and then runs the cluster
 * classification driver over the input.
 *
 * @param conf not referenced in this method body
 * @param input the input path handed to the classification driver
 * @param clustersIn the path the clustering policy is written to
 * @param output the path handed to the classification driver; the clustered-points directory is
 *        created beneath it
 * @param measure the distance measure; only logged here
 * @param clusterClassificationThreshold the threshold handed to the classification driver
 * @param runSequential forwarded to the classification driver
 * @throws IOException propagated from the calls made here
 * @throws InterruptedException propagated from the calls made here
 * @throws ClassNotFoundException propagated from the calls made here
 */
public static void clusterData(Configuration conf, Path input, Path clustersIn, Path output, DistanceMeasure measure,
    double clusterClassificationThreshold, boolean runSequential) throws IOException, InterruptedException,
    ClassNotFoundException {
  // Log what is about to run.
  if (log.isInfoEnabled()) {
    log.info("Running Clustering");
    Object[] logArgs = new Object[] {input, clustersIn, output, measure};
    log.info("Input: {} Clusters In: {} Out: {} Distance: {}", logArgs);
  }

  KMeansClusteringPolicy policy = new KMeansClusteringPolicy();
  ClusterClassifier.writePolicy(policy, clustersIn);

  // Run the classification; results go under the clustered-points directory of the output.
  // NOTE(review): the literal true is presumably the "emit most likely" flag - confirm against
  // ClusterClassificationDriver.run.
  Path clusteredPoints = new Path(output, CLUSTERED_POINTS_DIRECTORY);
  ClusterClassificationDriver.run(input, output, clusteredPoints, clusterClassificationThreshold, true,
      runSequential);
}