// Flink实现K-Means算法案例实战 (Implementing K-Means clustering with Apache Flink: a worked example)

package com.cnic.algorithm.flink.kmeans001;


import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;

/***
 * 提供使用K-Means示例程序的默认数据集,如果没有为程序提供参数,则使用默认数据集
 */
/**
 * K-Means clustering example on Flink's DataSet API.
 *
 * <p>Bundles the default input data (points and initial centroids), the data types, the user
 * functions used inside the bulk iteration, and the {@code main} driver. When no input files are
 * supplied on the command line, the built-in default data sets are used.
 *
 * <p>Optional command-line parameters:
 * <ul>
 *   <li>{@code --points <path>}: file with one space-delimited {@code "x y"} point per line</li>
 *   <li>{@code --centroids <path>}: file with one space-delimited {@code "id x y"} centroid per line</li>
 *   <li>{@code --iterations <n>} (legacy alias {@code --iteration <n>}): number of iterations, default 10</li>
 *   <li>{@code --output <path>}: write {@code "clusterId x y"} lines there instead of printing to stdout</li>
 * </ul>
 */
public class KMeansData {

    /** Number of K-Means iterations used when none is given on the command line. */
    private static final int DEFAULT_ITERATIONS = 10;

    /**
     * A mutable 2-D point. Public fields and a public no-arg constructor keep it usable as a
     * Flink POJO type (e.g. with {@code pojoType(Point.class, "x", "y")}).
     */
    public static class Point implements Serializable {
        // x and y coordinates
        public double x, y;

        public Point() {}

        public Point(double x, double y) {
            this.x = x;
            this.y = y;
        }

        /** Adds {@code other}'s coordinates to this point in place and returns {@code this}. */
        public Point add(Point other) {
            x += other.x;
            y += other.y;
            return this;
        }

        /** Divides both coordinates by {@code val} in place and returns {@code this}. */
        public Point div(long val) {
            x /= val;
            y /= val;
            return this;
        }

        /** Returns the Euclidean distance between this point and {@code other}. */
        public double euclideanDistance(Point other) {
            return Math.sqrt((x - other.x) * (x - other.x) + (y - other.y) * (y - other.y));
        }

        /** Resets both coordinates to 0. */
        public void clear() {
            x = y = 0.0;
        }

        /** Formats the point as {@code "x y"}. */
        @Override
        public String toString() {
            return x + " " + y;
        }
    }

    /**
     * A cluster center: a {@link Point} tagged with an integer cluster id.
     */
    public static class Centroid extends Point {

        public int id;

        public Centroid() {}

        public Centroid(int id, double x, double y) {
            super(x, y);
            this.id = id;
        }

        public Centroid(int id, Point p) {
            super(p.x, p.y);
            this.id = id;
        }

        /** Formats the centroid as {@code "id x y"}. */
        @Override
        public String toString() {
            return id + " " + super.toString();
        }
    }

    /**
     * Default initial centroids, one row per centroid: {@code {id, x, y}}.
     */
    public static final Object[][] CENTROIDS = new Object[][]{
            new Object[]{1, -31.85, -44.77},
            new Object[]{2, 35.16, 17.46},
            new Object[]{3, -5.16, 21.93},
            new Object[]{4, -24.06, 6.81}
    };

    /**
     * Default input points, one row per point: {@code {x, y}}.
     */
    public static final Object[][] POINTS = new Object[][]{
            new Object[]{-14.22, -48.01},
            new Object[]{-22.78, 37.10},
            new Object[]{56.18, -42.99},
            new Object[]{35.04, 50.29},
            new Object[]{-9.53, -46.26},
            new Object[]{-34.35, 48.25},
            new Object[]{55.82, -57.49},
            new Object[]{21.03, 54.64},
            new Object[]{-13.63, -42.26},
            new Object[]{-36.57, 32.63},
            new Object[]{50.65, -52.40},
            new Object[]{24.48, 34.04},
            new Object[]{-2.69, -36.02},
            new Object[]{-38.80, 36.58},
            new Object[]{24.00, -53.74},
            new Object[]{32.41, 24.96},
            new Object[]{-4.32, -56.92},
            new Object[]{-22.68, 29.42},
            new Object[]{59.02, -39.56},
            new Object[]{24.47, 45.07},
            new Object[]{5.23, -41.20},
            new Object[]{-23.00, 38.15},
            new Object[]{44.55, -51.50},
            new Object[]{14.62, 59.06},
            new Object[]{7.41, -56.05},
            new Object[]{-26.63, 28.97},
            new Object[]{47.37, -44.72},
            new Object[]{29.07, 51.06},
            new Object[]{0.59, -31.89},
            new Object[]{-39.09, 20.78},
            new Object[]{42.97, -48.98},
            new Object[]{34.36, 49.08},
            new Object[]{-21.91, -49.01},
            new Object[]{-46.68, 46.04},
            new Object[]{48.52, -43.67},
            new Object[]{30.05, 49.25},
            new Object[]{4.03, -43.56},
            new Object[]{-37.85, 41.72},
            new Object[]{38.24, -48.32},
            new Object[]{20.83, 57.85}
    };

    /**
     * Builds a data set from the built-in {@link #CENTROIDS} table.
     *
     * @param params unused; kept for signature compatibility with existing callers
     * @param env    environment used to create the data set
     * @return one {@link Centroid} per row of {@link #CENTROIDS}
     */
    public static DataSet<Centroid> getDefaultCentroidDataSet(ParameterTool params, ExecutionEnvironment env) {
        List<Centroid> centroidList = new ArrayList<>(CENTROIDS.length);
        for (Object[] centroid : CENTROIDS) {
            // Convert via Number so an integral literal (e.g. 35 instead of 35.0) does not throw
            // a ClassCastException.
            centroidList.add(new Centroid(
                    ((Number) centroid[0]).intValue(),
                    ((Number) centroid[1]).doubleValue(),
                    ((Number) centroid[2]).doubleValue()));
        }
        return env.fromCollection(centroidList);
    }

    /**
     * Builds a data set from the built-in {@link #POINTS} table.
     *
     * @param params unused; kept for signature compatibility with existing callers
     * @param env    environment used to create the data set
     * @return one {@link Point} per row of {@link #POINTS}
     */
    public static DataSet<Point> getDefaultPointDataSet(ParameterTool params, ExecutionEnvironment env) {
        List<Point> pointList = new ArrayList<>(POINTS.length);
        for (Object[] point : POINTS) {
            pointList.add(new Point(
                    ((Number) point[0]).doubleValue(),
                    ((Number) point[1]).doubleValue()));
        }
        return env.fromCollection(pointList);
    }

    /** Reads centroids from {@code --centroids} ("id x y" lines) if given, else the defaults. */
    private static DataSet<Centroid> getCentroidDataSet(ParameterTool params, ExecutionEnvironment env) {
        if (params.has("centroids")) {
            return env.readCsvFile(params.get("centroids"))
                    .fieldDelimiter(" ")
                    .pojoType(Centroid.class, "id", "x", "y");
        }
        return getDefaultCentroidDataSet(params, env);
    }

    /** Reads points from {@code --points} ("x y" lines) if given, else the defaults. */
    private static DataSet<Point> getPointDataSet(ParameterTool params, ExecutionEnvironment env) {
        if (params.has("points")) {
            return env.readCsvFile(params.get("points"))
                    .fieldDelimiter(" ")
                    .pojoType(Point.class, "x", "y");
        }
        return getDefaultPointDataSet(params, env);
    }

    /**
     * Assigns each point to its nearest centroid.
     *
     * <p>The centroids are read from the broadcast variable {@code "centroids"}. The output is
     * {@code (closestCentroidId, point)}. If the broadcast set is empty the id is {@code -1}.
     */
    @FunctionAnnotation.ForwardedFields("* -> 1")
    public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> {

        private Collection<Centroid> centroids;

        /** Loads the broadcast centroid collection before any {@link #map} call. */
        @Override
        public void open(Configuration parameters) throws Exception {
            this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
        }

        @Override
        public Tuple2<Integer, Point> map(Point p) throws Exception {
            double minDistance = Double.MAX_VALUE;
            int closestCentroidId = -1;
            // Linear scan over all centroids, keeping the first one at minimum distance.
            for (Centroid centroid : centroids) {
                double distance = p.euclideanDistance(centroid);
                if (distance < minDistance) {
                    minDistance = distance;
                    closestCentroidId = centroid.id;
                }
            }
            return new Tuple2<Integer, Point>(closestCentroidId, p);
        }
    }

    /** Appends a count of 1 to each {@code (clusterId, point)} record for later summation. */
    @FunctionAnnotation.ForwardedFields("f0;f1")
    public static final class CountAppender implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, Point, Long>> {

        @Override
        public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> t) {
            return new Tuple3<Integer, Point, Long>(t.f0, t.f1, 1L);
        }
    }

    /**
     * Sums the coordinates and the counts of all points that share a cluster id.
     *
     * <p>Note: the coordinate sum is accumulated in place into {@code val1.f1}.
     */
    @FunctionAnnotation.ForwardedFields("0")
    public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, Point, Long>> {

        @Override
        public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> val1, Tuple3<Integer, Point, Long> val2) {
            return new Tuple3<Integer, Point, Long>(val1.f0, val1.f1.add(val2.f1), val1.f2 + val2.f2);
        }
    }

    /** Computes the new centroid of a cluster as (coordinate sum) / (point count). */
    @FunctionAnnotation.ForwardedFields("0->id")
    public static final class CentroidAverager implements MapFunction<Tuple3<Integer, Point, Long>, Centroid> {

        @Override
        public Centroid map(Tuple3<Integer, Point, Long> value) {
            return new Centroid(value.f0, value.f1.div(value.f2));
        }
    }

    /**
     * Runs K-Means as a Flink bulk iteration and emits each point with its final cluster id.
     *
     * @param args command-line arguments; see the class documentation for supported keys
     */
    public static void main(String[] args) throws Exception {
        final ParameterTool params = ParameterTool.fromArgs(args);
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        // Make the parameters visible in the web UI.
        env.getConfig().setGlobalJobParameters(params);

        // Input: files from --points / --centroids when given, otherwise the built-in defaults.
        DataSet<Point> points = getPointDataSet(params, env);
        DataSet<Centroid> centroids = getCentroidDataSet(params, env);

        // Accept the conventional "--iterations" and keep the legacy "--iteration" working.
        int iterations = params.getInt("iterations", params.getInt("iteration", DEFAULT_ITERATIONS));
        IterativeDataSet<Centroid> loop = centroids.iterate(iterations);

        // One K-Means step: assign points, sum per cluster, average into new centroids.
        DataSet<Centroid> newCentroids = points
                .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
                .map(new CountAppender())
                .groupBy(0).reduce(new CentroidAccumulator())
                .map(new CentroidAverager());

        // Feed the new centroids back into the next iteration.
        DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);

        // Assign every point to its final cluster.
        DataSet<Tuple2<Integer, Point>> clusteredPoints = points
                .map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

        if (params.has("output")) {
            // Row delimiter "\n", field delimiter " " -> lines of "clusterId x y".
            clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
            env.execute("KMeans Example");
        } else {
            System.out.println("Printing result to stdout. Use --output to specify output path.");
            // print() triggers execution itself; no env.execute() needed here.
            clusteredPoints.print();
        }
    }
}

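/*
 * Residue of the original web page (not part of the program), kept for reference:
 *
 *   • 0
 *     点赞
 *   • 1
 *     收藏
 *     觉得还不错? 一键收藏
 *   • 0
 *     评论
 * 评论
 * 添加红包
 *
 * 请填写红包祝福语或标题
 *
 * 红包个数最小为10个
 *
 * 红包金额最低5元
 *
 * 当前余额3.43前往充值 >
 * 需支付:10.00
 * 成就一亿技术人!
 * 领取后你会自动成为博主和红包主的粉丝 规则
 * hope_wisdom
 * 发出的红包
 * 实付
 * 使用余额支付
 * 点击重新获取
 * 扫码支付
 * 钱包余额 0
 *
 * 抵扣说明:
 *
 * 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
 * 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
 *
 * 余额充值
 */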