FlinkML 之 k 近邻连接(k-Nearest Neighbors Join,Java 实现)

代码:
package cn.xsy.flink.ml;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.operators.base.CrossOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.ml.common.ParameterMap;
import org.apache.flink.ml.math.BreezeVectorConverter;
import org.apache.flink.ml.math.DenseVector;
import org.apache.flink.ml.math.Vector;
import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric;
import org.apache.flink.ml.nn.KNN;
import org.apache.flink.ml.pipeline.FitOperation;
import org.apache.flink.ml.pipeline.PredictDataSetOperation;
import org.apache.flink.ml.pipeline.TransformDataSetOperation;
import org.apache.flink.ml.preprocessing.MinMaxScaler;
import scala.Tuple2;
import scala.reflect.ClassTag$;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Demonstrates calling the FlinkML (Scala) k-nearest-neighbors join from Java.
 *
 * <p>Two small in-memory data sets of comma-separated feature strings are parsed into
 * {@link DenseVector}s, wrapped into Scala {@code DataSet}s, fed to {@link KNN}, and the
 * neighbors returned for every test vector are printed.
 */
public class KNNDemo {

    /**
     * Runs the demo: fits a {@link KNN} with k = 3 on the training vectors, predicts on the
     * test vectors and prints one line per test vector.
     *
     * @param args command-line arguments (not used)
     * @throws Exception if building or executing the Flink job fails
     */
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        // Training set
        DataSet<DenseVector> trainingSet = env
                .fromElements("5.8,2.8,5.1,2.4", "6.0,2.2,4.0,1.0", "5.5,4.2,1.4,0.2", "7.3,2.9,6.3,1.8", "5.0,3.4,1.5,0.2",
                        "6.3,3.3,6.0,2.5", "5.0,3.5,1.3,0.3", "6.7,3.1,4.7,1.5", "6.8,2.8,4.8,1.4", "6.1,2.8,4.0,1.3",
                        "6.1,2.6,5.6,1.4", "6.4,3.2,4.5,1.5", "6.1,2.8,4.7,1.2", "6.5,2.8,4.6,1.5", "6.1,2.9,4.7,1.4",
                        "4.9,3.6,1.4,0.1", "6.0,2.9,4.5,1.5", "5.5,2.6,4.4,1.2", "4.8,3.0,1.4,0.3", "5.4,3.9,1.3,0.4")
                // Convert each String into a DenseVector
                .map(new Normalizer());
        // Testing set
        DataSet<DenseVector> testingSet = env
                .fromElements("5.6,2.8,4.9,2.0", "5.6,3.0,4.5,1.5", "4.8,3.4,1.9,0.2", "4.4,2.9,1.4,0.2", "6.2,2.8,4.8,1.8")
                // Convert each String into a DenseVector
                .map(new Normalizer());

        // Wrap the Java DataSets into Scala DataSets. The ClassTag describes the ELEMENT type
        // of the data set (DenseVector), not the DataSet class itself.
        org.apache.flink.api.scala.DataSet<DenseVector> trainingSetScala = new org.apache.flink.api.scala.DataSet<>(
                trainingSet, ClassTag$.MODULE$.apply(DenseVector.class));
        org.apache.flink.api.scala.DataSet<DenseVector> testingSetScala = new org.apache.flink.api.scala.DataSet<>(
                testingSet, ClassTag$.MODULE$.apply(DenseVector.class));

        // Optional: scale the training and testing sets into the range [0,5]
//        MinMaxScaler minMaxScaler = new MinMaxScaler();
//        minMaxScaler.setMax(5.0).setMin(0.0);
//        minMaxScaler.fit(trainingSetScala,  ParameterMap.Empty(), (FitOperation) MinMaxScaler.fitVectorMinMaxScaler());
//        TransformDataSetOperation transformDataSetOperation = (TransformDataSetOperation) MinMaxScaler
//                .transformVectors((BreezeVectorConverter) DenseVector.denseVectorConverter(),
//                        TypeInformation.of(DenseVector.class), ClassTag$.MODULE$.apply(DenseVector.class));
//        org.apache.flink.api.scala.DataSet trainingSetScaler = minMaxScaler
//                .transform(trainingSetScala, ParameterMap.Empty(), transformDataSetOperation);
//        org.apache.flink.api.scala.DataSet testingSetScaler = minMaxScaler
//                .transform(testingSetScala, ParameterMap.Empty(), transformDataSetOperation);
//        trainingSetScaler.print();
//        testingSetScaler.print();

        // Configure the KNN predictor
        KNN knn = new KNN();
        knn.setK(3)
                .setBlocks(10)
                .setDistanceMetric(new SquaredEuclideanDistanceMetric())
                .setUseQuadTree(false)
                .setSizeHint(CrossOperatorBase.CrossHint.SECOND_IS_SMALL);
        // Fit on the training set. The raw cast stands in for the implicit FitOperation that
        // Scala callers get resolved automatically.
        knn.fit(trainingSetScala, ParameterMap.Empty(),
                (FitOperation) KNN.fitKNN(TypeInformation.of(DenseVector.class)));

        PredictDataSetOperation predictDataSetOperation = (PredictDataSetOperation) KNN
                .predictValues(ClassTag$.MODULE$.apply(DenseVector.class), TypeInformation.of(DenseVector.class));
        // run knn join
        org.apache.flink.api.scala.DataSet<Tuple2<Vector, Vector[]>> predict = knn
                .predict(testingSetScala, ParameterMap.Empty(), predictDataSetOperation);
//        predict.print();
        // Print the join result; collect() at the end of the chain triggers execution.
        predict.map(new MapFunction<Tuple2<Vector, Vector[]>, Tuple2<Vector, Vector[]>>() {
            @Override
            public Tuple2<Vector, Vector[]> map(Tuple2<Vector, Vector[]> value) throws Exception {
                // Assemble the whole line first and emit it with ONE println call, so that
                // output from mappers running in parallel cannot interleave mid-line.
                StringBuilder line = new StringBuilder();
                line.append(value._1).append(" k近邻为: ");
                for (Vector neighbor : value._2) {
                    line.append(neighbor).append(" ");
                }
                System.out.println(line);
                return value;
            }
        }, TypeInformation.of(new TypeHint<Tuple2<Vector, Vector[]>>() {
        }), ClassTag$.MODULE$.apply(Tuple2.class))
                .collect();
    }

    /**
     * Parses a comma-separated list of decimal numbers into a {@link DenseVector}.
     */
    public static final class Normalizer implements MapFunction<String, DenseVector> {

        /**
         * Splits {@code value} on commas and parses every token as a {@code double}.
         *
         * @param value comma-separated numbers, e.g. {@code "5.8,2.8,5.1,2.4"}
         * @return a dense vector holding the parsed values in input order
         * @throws NumberFormatException if a token is not a parsable double
         */
        @Override
        public DenseVector map(String value) throws Exception {
            // Parse straight into a primitive array; no boxing or intermediate list needed.
            double[] doubles = Stream.of(value.split(","))
                    .mapToDouble(Double::parseDouble)
                    .toArray();
            return new DenseVector(doubles);
        }
    }

}

运行结果:
DenseVector(4.4, 2.9, 1.4, 0.2) k近邻为: DenseVector(4.8, 3.0, 1.4, 0.3) DenseVector(5.0, 3.4, 1.5, 0.2) DenseVector(5.0, 3.5, 1.3, 0.3) 
DenseVector(5.6, 2.8, 4.9, 2.0) k近邻为: DenseVector(5.8, 2.8, 5.1, 2.4) DenseVector(6.0, 2.9, 4.5, 1.5) DenseVector(6.1, 2.9, 4.7, 1.4) 
DenseVector(5.6, 3.0, 4.5, 1.5) k近邻为: DenseVector(6.0, 2.9, 4.5, 1.5) DenseVector(6.0, 2.9, 4.5, 1.5) DenseVector(5.5, 2.6, 4.4, 1.2) 
DenseVector(4.8, 3.4, 1.9, 0.2) k近邻为: DenseVector(5.0, 3.4, 1.5, 0.2) DenseVector(4.9, 3.6, 1.4, 0.1) DenseVector(4.8, 3.0, 1.4, 0.3) 
DenseVector(6.2, 2.8, 4.8, 1.8) k近邻为: DenseVector(6.1, 2.9, 4.7, 1.4) DenseVector(6.5, 2.8, 4.6, 1.5) DenseVector(6.0, 2.9, 4.5, 1.5) 
官方文档地址:

https://ci.apache.org/projects/flink/flink-docs-release-1.8/dev/libs/ml/knn.html#k-nearest-neighbors-join

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值