/*
 * K-Means 聚类算法(Flink 版本)
 *
 * 具体实现原理:
 * K-Means 是迭代的聚类算法,初始设置 K 个聚类中心。
 * - 在每一次迭代过程中,算法计算每个数据点到每个聚类中心的欧式距离;
 * - 每个点被分配到离它最近的聚类中心;
 * - 随后每个聚类中心被移动到所有分配给它的点的均值位置;
 * - 移动后的聚类中心被用于下一次迭代。
 * 算法在固定次数的迭代之后终止(在本实现中,迭代次数由参数 iterations 设置),
 * 或者聚类中心在迭代中不再移动。
 *
 * 本项目工作在二维平面的数据点上:
 * 它计算分配给聚类中心的数据点,
 * 每个数据点都使用其所属的最终聚类(中心)的 id 进行标注。
 */
package com.eat.dsc.analyze.takeout.shop.mltask;
import com.eat.dsc.analyze.takeout.shop.utils.KMeansData;
import org.apache.commons.lang3.StringUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import java.io.Serializable;
import java.util.List;
/**
 * K-Means clustering on 2-D points, implemented with Flink's (legacy) DataSet API.
 *
 * <p>The job runs Lloyd's algorithm as a Flink bulk iteration: each round assigns
 * every point to its nearest centroid, then recomputes each centroid as the mean
 * of its assigned points. After the iteration finishes, every point is emitted
 * together with the id of its final cluster.
 *
 * <p>CLI parameters (all optional): {@code --points} and {@code --centroids}
 * (space-delimited CSV inputs; defaults come from {@code KMeansData}),
 * {@code --iterations} (default 10) and {@code --output} (CSV sink; without it
 * the result is printed to stdout).
 */
public class KmeanTask {

    public static void main(String[] args) throws Exception {
        final ParameterTool params = ParameterTool.fromArgs(args);

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        // Expose the CLI parameters in the Flink web UI.
        env.getConfig().setGlobalJobParameters(params);

        DataSet<Point> points = getPointDataSet(params, env);
        DataSet<Centroid> centroids = getCentroidDataSet(params, env);

        // Bulk iteration over the centroids; the point data set is re-read each round.
        IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));

        DataSet<Centroid> newCentroids = points
                // assign each point to the centroid it is closest to
                .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
                // attach a count of 1 so coordinate sums and point counts reduce together
                .map(new CountAppender())
                .groupBy(0)
                // per cluster: sum up coordinates and counts
                .reduce(new CentroidAccumulator())
                // mean position = coordinate sum / point count
                .map(new CentroidAverager());

        // Feed the recomputed centroids into the next round.
        DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);

        // Annotate every point with the id of its final cluster.
        DataSet<Tuple2<Integer, Point>> clusteredPoints = points
                .map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

        if (params.has("output")) {
            clusteredPoints.writeAsCsv(params.get("output"), "\n", StringUtils.SPACE);
            env.execute("KMeans Example");
        } else {
            System.out.println("Printing result to stdout. Use --output to specify output path.");
            clusteredPoints.print();
        }
    }

    /**
     * Reads the initial centroids from {@code --centroids} (space-delimited CSV:
     * id, x, y) or falls back to the bundled default data set.
     */
    private static DataSet<Centroid> getCentroidDataSet(ParameterTool params, ExecutionEnvironment env) {
        if (!params.has("centroids")) {
            System.out.println("执行 K-Means 用默认的中心数据集合.");
            System.out.println("Use --centroids to specify file input.");
            return KMeansData.getDefaultCentroidDataSet(env);
        }
        return env.readCsvFile(params.get("centroids"))
                .fieldDelimiter(StringUtils.SPACE)
                .pojoType(Centroid.class, "id", "x", "y");
    }

    /**
     * Reads the input points from {@code --points} (space-delimited CSV: x, y)
     * or falls back to the bundled default data set.
     */
    private static DataSet<Point> getPointDataSet(ParameterTool params, ExecutionEnvironment env) {
        if (!params.has("points")) {
            System.out.println("Executing K-Means example with default point data set.");
            System.out.println("Use --points to specify file input.");
            return KMeansData.getDefaultPointDataSet(env);
        }
        return env.readCsvFile(params.get("points"))
                .fieldDelimiter(StringUtils.SPACE)
                .pojoType(Point.class, "x", "y");
    }

    /**
     * A mutable 2-D point. Public fields and the no-arg constructor are required
     * for Flink's POJO serialization.
     */
    public static class Point implements Serializable {

        public double x, y;

        public Point() {}

        public Point(double x, double y) {
            this.x = x;
            this.y = y;
        }

        /** Adds {@code other} component-wise into this point and returns {@code this}. */
        public Point add(Point other) {
            this.x += other.x;
            this.y += other.y;
            return this;
        }

        /** Divides both coordinates by {@code val} in place and returns {@code this}. */
        public Point div(long val) {
            this.x /= val;
            this.y /= val;
            return this;
        }

        /** Euclidean (L2) distance between this point and {@code other}. */
        public double euclideanDistance(Point other) {
            double dx = x - other.x;
            double dy = y - other.y;
            return Math.sqrt(dx * dx + dy * dy);
        }

        /** Resets both coordinates to zero. */
        public void clear() {
            x = 0.0;
            y = 0.0;
        }

        @Override
        public String toString() {
            return x + StringUtils.SPACE + y;
        }
    }

    /** A cluster center: a {@link Point} plus a numeric cluster id. */
    public static class Centroid extends Point {

        public int id;

        public Centroid() {}

        public Centroid(int id, double x, double y) {
            super(x, y);
            this.id = id;
        }

        public Centroid(int id, Point p) {
            super(p.x, p.y);
            this.id = id;
        }

        @Override
        public String toString() {
            return id + " " + super.toString();
        }
    }

    /**
     * Maps each point to {@code (id of nearest centroid, point)}. The current
     * centroids arrive via the "centroids" broadcast variable.
     */
    @FunctionAnnotation.ForwardedFields("*->1")
    public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> {

        private List<Centroid> centroids;

        @Override
        public void open(Configuration parameters) throws Exception {
            // Fetch the broadcast centroids once per task, before any map() call.
            this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
        }

        @Override
        public Tuple2<Integer, Point> map(Point p) throws Exception {
            double bestDistance = Double.MAX_VALUE;
            int bestCentroidId = -1;
            // Linear scan over all centroids for the smallest Euclidean distance.
            for (Centroid candidate : centroids) {
                double distance = p.euclideanDistance(candidate);
                if (distance < bestDistance) {
                    bestDistance = distance;
                    bestCentroidId = candidate.id;
                }
            }
            return new Tuple2<>(bestCentroidId, p);
        }
    }

    /** Appends a count of 1 to each (centroid id, point) pair for later summing. */
    @FunctionAnnotation.ForwardedFields("f0;f1")
    public static final class CountAppender implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, Point, Long>> {

        @Override
        public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> assignment) {
            return new Tuple3<>(assignment.f0, assignment.f1, 1L);
        }
    }

    /** Per cluster id: sums point coordinates and point counts. */
    public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, Point, Long>> {

        @Override
        public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> left, Tuple3<Integer, Point, Long> right) {
            // Accumulates into left's Point; safe here because Flink's DataSet
            // reduce hands over exclusive instances by default.
            return new Tuple3<>(left.f0, left.f1.add(right.f1), left.f2 + right.f2);
        }
    }

    /** Turns an accumulated (id, coordinate sum, count) into the new mean centroid. */
    public static final class CentroidAverager implements MapFunction<Tuple3<Integer, Point, Long>, Centroid> {

        @Override
        public Centroid map(Tuple3<Integer, Point, Long> accumulated) {
            return new Centroid(accumulated.f0, accumulated.f1.div(accumulated.f2));
        }
    }
}