// Dispatch the clustering iteration to either the in-process (sequential)
// or the MapReduce implementation, based on the caller's flag.
if (!runSequential) {
    ClusterIterator.iterateMR(conf, input, priorClustersPath, output, maxIterations);
} else {
    ClusterIterator.iterateSeq(conf, input, priorClustersPath, output, maxIterations);
}
/**
 * Runs the clustering iteration as a sequence of MapReduce jobs.
 * <p>
 * Each iteration reads the prior clusters, classifies/trains over the input,
 * and writes the updated clusters to {@code outPath/clusters-<i>}. Iteration
 * stops after {@code numIterations} rounds or as soon as {@code isConverged}
 * reports convergence. The last clusters directory is renamed with
 * {@link Cluster#FINAL_ITERATION_SUFFIX} so downstream consumers can find it.
 *
 * @param conf          Hadoop configuration; PRIOR_PATH_KEY is updated each round
 * @param inPath        input vectors (SequenceFile)
 * @param priorPath     initial prior clusters; advanced to the latest output each round
 * @param outPath       base output directory for per-iteration cluster dirs
 * @param numIterations maximum number of iterations; must be >= 1
 * @throws IllegalArgumentException if numIterations < 1 (the original code would
 *                                  NPE on the final rename in that case)
 * @throws InterruptedException     if a MapReduce job fails
 * @throws IOException              on filesystem errors, including a failed rename
 */
public static void iterateMR(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations)
    throws IOException, InterruptedException, ClassNotFoundException {
  if (numIterations < 1) {
    // Guard: with zero iterations clustersOut would remain null and the
    // rename below would throw a NullPointerException.
    throw new IllegalArgumentException("numIterations must be >= 1 but was " + numIterations);
  }
  ClusteringPolicy policy = ClusterClassifier.readPolicy(priorPath);
  // All per-iteration outputs live under outPath, so one FileSystem handle
  // suffices for the whole loop (hoisted out of the loop body).
  FileSystem fs = FileSystem.get(outPath.toUri(), conf);
  Path clustersOut = null;
  int iteration = 1;
  // Loop until the iteration budget is exhausted or isConverged() succeeds.
  while (iteration <= numIterations) {
    conf.set(PRIOR_PATH_KEY, priorPath.toString());
    String jobName = "Cluster Iterator running iteration " + iteration + " over priorPath: " + priorPath;
    Job job = new Job(conf, jobName);
    job.setMapOutputKeyClass(IntWritable.class);
    job.setMapOutputValueClass(ClusterWritable.class);
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(ClusterWritable.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setMapperClass(CIMapper.class);
    job.setReducerClass(CIReducer.class);
    FileInputFormat.addInputPath(job, inPath);
    clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration);
    // This iteration's output becomes the next iteration's prior.
    priorPath = clustersOut;
    FileOutputFormat.setOutputPath(job, clustersOut);
    job.setJarByClass(ClusterIterator.class);
    if (!job.waitForCompletion(true)) {
      throw new InterruptedException("Cluster Iteration " + iteration + " failed processing " + priorPath);
    }
    ClusterClassifier.writePolicy(policy, clustersOut);
    iteration++;
    // Converged means every cluster's new center moved less than the configured
    // convergenceDelta from its previous center.
    if (isConverged(clustersOut, conf, fs)) {
      break;
    }
  }
  // Mark the last iteration's output as final; fail loudly if the rename is
  // refused instead of silently leaving no "-final" directory behind.
  Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX);
  if (!fs.rename(clustersOut, finalClustersIn)) {
    throw new IOException("Failed to rename " + clustersOut + " to " + finalClustersIn);
  }
}
/* CIMapper.map: classify one input vector and fold it into the selected cluster model(s). */
@Override
protected void map(WritableComparable<?> key, VectorWritable value, Context context) throws IOException,
    InterruptedException {
  // Score the vector against every cluster via the classifier.
  Vector scores = classifier.classify(value.get());
  // The policy picks which cluster(s) receive this vector, and with what weight.
  Vector weights = policy.select(scores);
  for (Element chosen : weights.nonZeroes()) {
    classifier.train(chosen.index(), value.get(), chosen.get());
  }
}
/* ClusterClassifier.classify: delegate per-cluster scoring of the instance to the policy. */
@Override
public Vector classify(Vector instance) {
  Vector scores = policy.classify(instance, this);
  return scores;
}
/* AbstractClusteringPolicy.classify: evaluate every cluster model's pdf for the
   given vector, then normalize the scores so they sum to 1. */
@Override
public Vector classify(Vector data, ClusterClassifier prior) {
  List<Cluster> models = prior.getModels();
  Vector pdfs = new DenseVector(models.size());
  // The writable wrapper is read-only here, so one instance serves all models.
  VectorWritable wrapped = new VectorWritable(data);
  for (int i = 0; i < models.size(); i++) {
    pdfs.set(i, models.get(i).pdf(wrapped));
  }
  double total = pdfs.zSum();
  return pdfs.assign(new TimesFunction(), 1.0 / total);
}
/* DistanceMeasureCluster.pdf: turn distance-to-center into a score in (0, 1];
   smaller distances yield larger scores. */
@Override
public double pdf(VectorWritable vw) {
  double distanceToCenter = measure.distance(vw.get(), getCenter());
  return 1.0 / (1.0 + distanceToCenter);
}
/* CosineDistanceMeasure.distance: 1 - cosine similarity of the two vectors. */
@Override
public double distance(Vector v1, Vector v2) {
  if (v1.size() != v2.size()) {
    throw new CardinalityException(v1.size(), v2.size());
  }
  double dot = v2.dot(v1);
  double normProduct = Math.sqrt(v1.getLengthSquared()) * Math.sqrt(v2.getLengthSquared());
  double denominator = normProduct;
  // Floating-point rounding can make the norm product slightly smaller than
  // the dot product; clamp so similarity never exceeds 1.
  if (denominator < dot) {
    denominator = dot;
  }
  // Two zero vectors: define the distance as 0 instead of returning NaN.
  if (dot == 0 && denominator == 0) {
    return 0;
  }
  return 1.0 - dot / denominator;
}
/* ClusterClassifier.train: fold a weighted observation into the model at index actual. */
public void train(int actual, Vector data, double weight) {
  Cluster target = models.get(actual);
  target.observe(new VectorWritable(data), weight);
}
/* AbstractCluster.observe (writable overload): unwrap and delegate to the
   Vector overload, which updates the running S0/S1/S2 statistics. */
@Override
public void observe(VectorWritable x, double weight) {
  Vector unwrapped = x.get();
  observe(unwrapped, weight);
}
/* Accumulate a weighted observation into the running statistics:
   S0 += weight, S1 += weight * x, S2 += weight * x^2 (element-wise). */
public void observe(Vector x, double weight) {
  if (weight == 1.0) {
    // Unweighted fast path handled by the single-argument overload.
    observe(x);
    return;
  }
  setS0(getS0() + weight);
  Vector weightedX = x.times(weight);
  Vector s1 = getS1();
  if (s1 == null) {
    setS1(weightedX);
  } else {
    s1.assign(weightedX, Functions.PLUS);
  }
  Vector weightedX2 = x.times(x).times(weight);
  Vector s2 = getS2();
  if (s2 == null) {
    setS2(weightedX2);
  } else {
    s2.assign(weightedX2, Functions.PLUS);
  }
}
/* CIReducer中reduce方法,对这一轮加入集群的向量进行平均,从新计算集群中心*/
@Override
protected void reduce(IntWritable key, Iterable<ClusterWritable> values, Context context) throws IOException,
InterruptedException {
Iterator<ClusterWritable> iter = values.iterator();
Cluster first = iter.next().getValue(); // there must always be at least one
while (iter.hasNext()) {
Cluster cluster = iter.next().getValue();
first.observe(cluster);
}
List<Cluster> models = Lists.newArrayList();
models.add(first);
classifier = new ClusterClassifier(models, policy);
classifier.close();
context.write(key, new ClusterWritable(first));
} |