flink kmeans聚类算法实现

  • kmeans聚类算法 flink版本
  • 具体实现原理
    K-Means 是迭代的聚类算法,初始设置K个聚类中心
  1. 在每一次迭代过程中,算法计算每个数据点到每个聚类中心的欧式距离
  2. 每个点被分配到它最近的聚类中心
  3. 随后每个聚类中心被移动到所有被分配的点
  4. 移动的聚类中心被分配到下一次迭代
  5. 算法在固定次数的迭代之后终止(在本实现中,参数设置)
  6. 或者聚类中心在迭代中不在移动
  7. 本项目是工作在二维平面的数据点上
  8. 它计算分配给集群中心的数据点
  9. 每个数据点都使用其所属的最终集群(中心)的id进行注释。
package com.eat.dsc.analyze.takeout.shop.mltask;

import com.eat.dsc.analyze.takeout.shop.utils.KMeansData;
import org.apache.commons.lang3.StringUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;

import java.io.Serializable;
import java.util.List;

/**
 * K-Means 是迭代的聚类算法,初始设置K个聚类中心
 * 在每一次迭代过程中,算法计算每个数据点到每个聚类中心的欧式距离
 * 每个点被分配到它最近的聚类中心
 * 随后每个聚类中心被移动到所有被分配的点
 * 移动的聚类中心被分配到下一次迭代
 * 算法在固定次数的迭代之后终止(在本实现中,参数设置)
 * 或者聚类中心在迭代中不在移动
 * 本项目是工作在二维平面的数据点上
 * 它计算分配给集群中心的数据点
 * 每个数据点都使用其所属的最终集群(中心)的id进行注释。
 * For example <code>"1.2 2.3\n5.3 7.2\n"</code> gives two data points (x=1.2, y=2.3) and (x=5.3, y=7.2).
 * <li>Cluster centers are represented by an integer id and a point value.<br>
 * For example <code>"1 6.2 3.2\n2 2.9 5.7\n"</code> gives two centers (id=1, x=6.2, y=3.2) and (id=2, x=2.9, y=5.7).
 * </ul>
 * <p>Usage: KMeans --points &lt;path&gt; --centroids &lt;path&gt; --output &lt;path&gt; --iterations &lt;n&gt;</code><br>
 * 如果没有参数提供,项目使用默认数据运行聚类程序并迭代10次。
 **/
public class KmeanTask {
    public static void main(String[] args) throws Exception {
            final ParameterTool params = ParameterTool.fromArgs(args);

            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            env.getConfig().setGlobalJobParameters(params);

            // 获取输入的数据点和聚类中心,如果路径中有数据就读文件,否则取默认数据
            DataSet<Point> points = getPointDataSet(params, env);
            DataSet<Centroid> centroids = getCentroidDataSet(params, env);

            // 设置 K-Means算法的迭代次数
            IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));

        DataSet<Centroid> newCentroids = points
                    // 为每个点(point)计算最近的聚类中心
                    .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
                    // 每个聚类中心的点坐标的计数和求和
                    .map(new CountAppender())
                    .groupBy(0)
                    .reduce(new CentroidAccumulator())
                    // 从点计数和坐标,计算新的聚类中心
                    .map(new CentroidAverager());

            // 将新的中心点放到下一次迭代中,closeWith代表最后一次迭代
            DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);
            // 最后将分类和聚类的点生成元组
            DataSet<Tuple2<Integer, Point>> clusteredPoints = points
                    // 将point分派到最后聚类中
                    .map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

            // 将结果集存到csv文件中或者打印到控制台
            if (params.has("output")) {
                clusteredPoints.writeAsCsv(params.get("output"), "\n", StringUtils.SPACE);

                // since file sinks are lazy, we trigger the execution explicitly
                env.execute("KMeans Example");
            } else {
                System.out.println("Printing result to stdout. Use --output to specify output path.");
                clusteredPoints.print();
            }
        }

        // *************************************************************************
        //     数据源读取 (数据点和聚类中心)
        // *************************************************************************

        private static DataSet<Centroid> getCentroidDataSet(ParameterTool params, ExecutionEnvironment env) {
            DataSet<Centroid> centroids;
            if (params.has("centroids")) {
                centroids = env.readCsvFile(params.get("centroids"))
                        .fieldDelimiter(StringUtils.SPACE)
                        .pojoType(Centroid.class, "id", "x", "y");
            } else {
                System.out.println("执行 K-Means 用默认的中心数据集合.");
                System.out.println("Use --centroids to specify file input.");
                centroids = KMeansData.getDefaultCentroidDataSet(env);
            }
            return centroids;
        }

        private static DataSet<Point> getPointDataSet(ParameterTool params, ExecutionEnvironment env) {
            DataSet<Point> points;
            if (params.has("points")) {
                // read points from CSV file
                points = env.readCsvFile(params.get("points"))
                        .fieldDelimiter(StringUtils.SPACE)
                        .pojoType(Point.class, "x", "y");
            } else {
                System.out.println("Executing K-Means example with default point data set.");
                System.out.println("Use --points to specify file input.");
                points = KMeansData.getDefaultPointDataSet(env);
            }
            return points;
        }

    // *************************************************************************
    //    数据类型,POJO内部类
    // *************************************************************************

    /**
     * 简单的二维点.
     */
    public static class Point implements Serializable {

        public double x, y;

        public Point() {}

        public Point(double x, double y) {
            this.x = x;
            this.y = y;
        }

        public Point add(Point other) {
            x += other.x;
            y += other.y;
            return this;
        }

        public Point div(long val) {
            x /= val;
            y /= val;
            return this;
        }

        public double euclideanDistance(Point other) {
            return Math.sqrt((x - other.x) * (x - other.x) + (y - other.y) * (y - other.y));
        }


        public void clear() {
            x = y = 0.0;
        }

        @Override
        public String toString() {
            return x + StringUtils.SPACE + y;
        }
    }

    /**
     * 简单的二维中心,包括ID的点
     */
    public static class Centroid extends Point {

        public int id;

        public Centroid() {}

        public Centroid(int id, double x, double y) {
            super(x, y);
            this.id = id;
        }

        public Centroid(int id, Point p) {
            super(p.x, p.y);
            this.id = id;
        }

        @Override
        public String toString() {
            return id + " " + super.toString();
        }
    }

        // *************************************************************************
        //     自定义函数
        // *************************************************************************

        /** 从数据点确定最近的聚类中心. */
        @FunctionAnnotation.ForwardedFields("*->1")
        public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> {
            private List<Centroid> centroids;

            /** 从广播变量中读取聚类中心值到集合中. */
            @Override
            public void open(Configuration parameters) throws Exception {
                this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
            }

            @Override
            public Tuple2<Integer, Point> map(Point p) throws Exception {

                double minDistance = Double.MAX_VALUE;
                int closestCentroidId = -1;

                // 检查所有的聚类中心
                for (Centroid centroid : centroids) {
                    // 计算每个点与聚类中心的距离(欧式距离)
                    double distance = p.euclideanDistance(centroid);

                    // 满足条件更新最近的聚类中心Id
                    if (distance < minDistance) {
                        minDistance = distance;
                        closestCentroidId = centroid.id;
                    }
                }

                // 生成一个包含聚类中心id和数据点的元组tuple.
                return new Tuple2<>(closestCentroidId, p);
            }
        }

        /** 向tupel2追加计数变量. */
        @FunctionAnnotation.ForwardedFields("f0;f1")
        public static final class CountAppender implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, Point, Long>> {

            @Override
            public Tuple3<Integer/*id*/, Point, Long/*1L*/> map(Tuple2<Integer, Point> t) {
                return new Tuple3<>(t.f0, t.f1, 1L);
            }
        }

        /** 求同一个类所有点的x,y坐标总数和计数点坐标. */
        //@FunctionAnnotation.ForwardedFields("0")
        public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, Point, Long>> {

            @Override
            public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> val1, Tuple3<Integer, Point, Long> val2) {
                return new Tuple3<>(val1.f0, val1.f1.add(val2.f1), val1.f2 + val2.f2);
            }
        }

        /** 从坐标和点的个数计算新的聚类中心. */
        //@FunctionAnnotation.ForwardedFields("0->id")
        public static final class CentroidAverager implements MapFunction<Tuple3<Integer/*id*/, Point/*累加的坐标点*/, Long/*个数*/>, Centroid> {

            @Override
            public Centroid map(Tuple3<Integer, Point, Long> value) {
                return new Centroid(value.f0, value.f1.div(value.f2));
            }
        }
    }
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值