package org.apache.giraph.benchmark.kmeans;
import java.io.IOException;
import java.util.regex.Pattern;
import org.apache.giraph.graph.BasicComputation;
import org.apache.giraph.graph.Vertex;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
/*
 * K-means algorithm overview:
 * centerPoints[numberOfClusters][numberOfDimensions] — the first index is the cluster,
 * the second is the coordinate dimension of a vertex.
 * For each vertex, its per-dimension values are compared against every row of
 * centerPoints to find the cluster at the smallest Euclidean distance; the vertex is
 * assigned to that cluster and its coordinates are fed to aggregators for the master.
 * Each worker scans its own vertices; after the global superstep barrier the
 * aggregated per-cluster values become the new centerPoints for the next round. The
 * algorithm must wait for the last task to finish to obtain the final aggregated
 * values.
 */
public class KMeansComputation extends BasicComputation<Text, Text, Text, Text> {
  /** Splits CSV-encoded vertex values ("d0,d1,...[,clusterId]"); compiled once per class. */
  static final Pattern commonSpliter = Pattern.compile(",");

  /**
   * Executes one K-means superstep for a single vertex.
   *
   * <p>Superstep 0 seeds the global per-dimension min/max aggregators; every later
   * superstep (until {@code maxIterations}) assigns the vertex to its nearest cluster
   * center and aggregates its coordinates so the master can recompute the centers.
   *
   * @param vertex   vertex whose value is a CSV string of coordinates, with the assigned
   *                 cluster id appended from superstep 1 onward
   * @param messages unused — this algorithm communicates only through aggregators
   */
  @Override
  public void compute(Vertex<Text, Text, Text> vertex, Iterable<Text> messages)
      throws IOException {
    KMeansNodeWorkerContext workerContext = (KMeansNodeWorkerContext) getWorkerContext();
    long superstep = getSuperstep();
    int numberOfDim = workerContext.getNumberOfDimensions();
    // NOTE(review): round-tripping coordinates through a CSV Text every superstep is
    // slow; a Writable wrapping a double[] would be a better vertex value type.
    String origValue = vertex.getValue().toString();
    String[] pointsStrings = commonSpliter.split(origValue);
    if (superstep >= workerContext.getMaxIterations()) {
      vertex.voteToHalt();
    } else if (superstep == 0) {
      // Superstep 0: publish every coordinate to the global per-dimension min/max
      // aggregators so the master can derive the initial cluster centers.
      double[] points = new double[pointsStrings.length];
      for (int i = 0; i < pointsStrings.length; i++) {
        points[i] = Double.parseDouble(pointsStrings[i]);
      }
      for (int i = 0; i < numberOfDim; i++) {
        aggregate(Const.MAX_DIMENSION_PREFIX + "." + i, new DoubleWritable(points[i]));
        aggregate(Const.MIN_DIMENSION_PREFIX + "." + i, new DoubleWritable(points[i]));
      }
    } else {
      double[] points = parsePointsFromValue(superstep, pointsStrings);
      int clusterCenter = selectClusterCenter(workerContext, points); // nearest center
      applyClusterCenterAggregates(clusterCenter, points);
      updateValue(vertex, superstep, origValue, clusterCenter);
    }
  }

  /**
   * Parses the numeric coordinates out of the split vertex value.
   * From superstep 2 onward the value carries a trailing cluster id, which is skipped.
   *
   * @param superstep     current superstep (decides whether a cluster id is present)
   * @param pointsStrings the comma-split vertex value
   * @return the coordinate vector of this vertex
   */
  private double[] parsePointsFromValue(long superstep, String[] pointsStrings) {
    int dims = pointsStrings.length - (superstep == 1 ? 0 : 1);
    double[] points = new double[dims];
    for (int i = 0; i < points.length; i++) {
      points[i] = Double.parseDouble(pointsStrings[i]);
    }
    return points;
  }

  /**
   * Stores the assigned cluster id as the last CSV token of the vertex value.
   *
   * @param vertex        vertex to update
   * @param superstep     current superstep
   * @param origValue     the vertex value before this superstep
   * @param clusterCenter index of the cluster this vertex was assigned to
   */
  private void updateValue(Vertex<Text, Text, Text> vertex, long superstep, String origValue,
      int clusterCenter) {
    // NOTE(review): string surgery every superstep is slow for the same reason noted
    // in compute(); kept for compatibility with the CSV value format.
    if (superstep == 1) {
      // First assignment: append the cluster id to the raw coordinates.
      vertex.setValue(new Text(origValue + "," + clusterCenter));
    } else {
      // Replace the previous cluster id (the last comma-separated token).
      vertex.setValue(new Text(
          origValue.substring(0, origValue.lastIndexOf(',')) + "," + clusterCenter));
    }
  }

  /**
   * Returns the index of the cluster center nearest to {@code points} in Euclidean
   * distance. Squared distances are compared directly: sqrt is monotonic, so the
   * argmin is identical while the per-dimension {@code Math.sqrt}/{@code Math.pow}
   * calls of the original formulation are avoided.
   *
   * @param workerContext supplies the current centers matrix
   * @param points        coordinate vector of the vertex
   * @return index of the nearest cluster, or -1 if there are no centers
   */
  private int selectClusterCenter(KMeansNodeWorkerContext workerContext, double points[]) {
    double[][] centerPoints = workerContext.getCenters();
    int selectedCluster = -1;
    double shortestSquaredDistance = Double.MAX_VALUE;
    for (int c = 0; c < centerPoints.length; c++) {
      double squaredDistance = 0;
      for (int d = 0; d < centerPoints[c].length; d++) {
        double diff = centerPoints[c][d] - points[d];
        squaredDistance += diff * diff;
      }
      if (squaredDistance < shortestSquaredDistance) {
        selectedCluster = c;
        shortestSquaredDistance = squaredDistance;
      }
    }
    return selectedCluster;
  }

  /**
   * Feeds this vertex's coordinates into the chosen cluster's per-dimension
   * aggregators and bumps that cluster's member count, so the master can compute the
   * new center after the superstep barrier.
   *
   * @param clusterCenter index of the assigned cluster
   * @param points        coordinate vector of the vertex
   */
  private void applyClusterCenterAggregates(int clusterCenter, double[] points) {
    for (int d = 0; d < points.length; d++) {
      aggregate(Const.MAX_DIMENSION_PREFIX + "." + clusterCenter + "." + d,
          new DoubleWritable(points[d]));
      aggregate(Const.MIN_DIMENSION_PREFIX + "." + clusterCenter + "." + d,
          new DoubleWritable(points[d]));
    }
    aggregate(Const.CLUSTER_NODE_COUNT_PREFIX + "." + clusterCenter, new LongWritable(1));
  }
}
package org.apache.giraph.benchmark.kmeans;
import java.util.regex.Pattern;
import org.apache.giraph.worker.DefaultWorkerContext;
import org.apache.hadoop.io.Text;
public class KMeansNodeWorkerContext extends DefaultWorkerContext {
  /** Current cluster centers: [numberOfClusters][numberOfDimensions]. */
  private double[][] centers;
  /** Number of clusters (k), read from the job configuration. */
  private int numberOfClusters;
  /** Number of coordinate dimensions per vertex, read from the job configuration. */
  private int numberOfDimensions;
  /** Superstep count after which vertices vote to halt. */
  private int maxIterations;
  /** Splits the aggregated CSV center string; compiled once per class. */
  private static final Pattern commaPattern = Pattern.compile(",");

  /**
   * Reads the k-means settings from the job configuration and allocates the
   * centers matrix before the first superstep.
   *
   * @throws IllegalStateException if a required configuration key is missing
   */
  @Override
  public void preApplication() throws InstantiationException,
      IllegalAccessException {
    numberOfClusters = requiredInt(Const.NUMBER_OF_CLUSTERS);
    numberOfDimensions = requiredInt(Const.NUMBER_OF_DIMENSIONS);
    maxIterations = requiredInt(Const.MAX_ITERATIONS);
    centers = new double[numberOfClusters][numberOfDimensions];
  }

  /**
   * Reads a required integer configuration value, failing with the key name instead
   * of an anonymous NullPointerException when the key is absent.
   */
  private int requiredInt(String key) {
    String raw = getContext().getConfiguration().get(key);
    if (raw == null) {
      throw new IllegalStateException("Missing required configuration value: " + key);
    }
    return Integer.parseInt(raw.trim());
  }

  @Override
  public void postApplication() {
  }

  /**
   * Refreshes {@link #centers} from the master's CENTER_POINTS aggregator: a
   * comma-separated list of numberOfClusters * numberOfDimensions values laid out
   * cluster-major. When the aggregator is absent or empty (e.g. before the first
   * centers are published), the existing matrix is left untouched.
   */
  @Override
  public void preSuperstep() {
    Text pointsText = (Text) getAggregatedValue(Const.CENTER_POINTS);
    if (pointsText == null) {
      return;
    }
    String pointsString = pointsText.toString();
    if (pointsString.isEmpty()) {
      return;
    }
    String[] points = commaPattern.split(pointsString);
    int pointIndex = 0;
    for (int c = 0; c < numberOfClusters; c++) {
      for (int d = 0; d < numberOfDimensions; d++) {
        centers[c][d] = Double.parseDouble(points[pointIndex]);
        pointIndex++;
      }
    }
  }

  @Override
  public void postSuperstep() {
  }

  /** @return the current cluster-centers matrix (live reference, not a copy). */
  public double[][] getCenters() {
    return centers;
  }

  /** @return the configured number of clusters (k). */
  public int getNumberOfClusters() {
    return numberOfClusters;
  }

  /** @return the configured number of coordinate dimensions per vertex. */
  public int getNumberOfDimensions() {
    return numberOfDimensions;
  }

  /** @return the configured maximum number of supersteps. */
  public int getMaxIterations() {
    return maxIterations;
  }
}