代码:
package cn.xsy.flink.ml;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.operators.base.CrossOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.ml.common.ParameterMap;
import org.apache.flink.ml.math.BreezeVectorConverter;
import org.apache.flink.ml.math.DenseVector;
import org.apache.flink.ml.math.Vector;
import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric;
import org.apache.flink.ml.nn.KNN;
import org.apache.flink.ml.pipeline.FitOperation;
import org.apache.flink.ml.pipeline.PredictDataSetOperation;
import org.apache.flink.ml.pipeline.TransformDataSetOperation;
import org.apache.flink.ml.preprocessing.MinMaxScaler;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Minimal FlinkML k-nearest-neighbours join example driven from Java.
 *
 * <p>Two small in-memory data sets of comma-separated feature strings are parsed into
 * {@link DenseVector}s, wrapped in Scala {@code DataSet}s (FlinkML only exposes a Scala API),
 * and handed to {@link KNN}. For every test vector the k = 3 nearest training vectors are
 * printed to standard output.
 */
public class KNNDemo {

    /**
     * Builds the training and testing sets, fits the KNN model and prints each test vector
     * followed by its nearest neighbours.
     *
     * @param args command-line arguments; not used
     * @throws Exception if building or executing the Flink job fails
     */
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<DenseVector> trainingSet = env
                .fromElements("5.8,2.8,5.1,2.4", "6.0,2.2,4.0,1.0", "5.5,4.2,1.4,0.2", "7.3,2.9,6.3,1.8", "5.0,3.4,1.5,0.2",
                        "6.3,3.3,6.0,2.5", "5.0,3.5,1.3,0.3", "6.7,3.1,4.7,1.5", "6.8,2.8,4.8,1.4", "6.1,2.8,4.0,1.3",
                        "6.1,2.6,5.6,1.4", "6.4,3.2,4.5,1.5", "6.1,2.8,4.7,1.2", "6.5,2.8,4.6,1.5", "6.1,2.9,4.7,1.4",
                        "4.9,3.6,1.4,0.1", "6.0,2.9,4.5,1.5", "5.5,2.6,4.4,1.2", "4.8,3.0,1.4,0.3", "5.4,3.9,1.3,0.4")
                .map(new Normalizer());

        DataSet<DenseVector> testingSet = env
                .fromElements("5.6,2.8,4.9,2.0", "5.6,3.0,4.5,1.5", "4.8,3.4,1.9,0.2", "4.4,2.9,1.4,0.2", "6.2,2.8,4.8,1.8")
                .map(new Normalizer());

        // The Scala DataSet[T] wrapper takes a ClassTag for its ELEMENT type T, so the tag must
        // describe DenseVector, not the DataSet container class.
        org.apache.flink.api.scala.DataSet<DenseVector> trainingSetScala = new org.apache.flink.api.scala.DataSet<>(
                trainingSet, ClassTag$.MODULE$.apply(DenseVector.class));
        org.apache.flink.api.scala.DataSet<DenseVector> testingSetScala = new org.apache.flink.api.scala.DataSet<>(
                testingSet, ClassTag$.MODULE$.apply(DenseVector.class));

        KNN knn = new KNN();
        knn.setK(3)
                .setBlocks(10)
                .setDistanceMetric(new SquaredEuclideanDistanceMetric())
                .setUseQuadTree(false)
                .setSizeHint(CrossOperatorBase.CrossHint.SECOND_IS_SMALL);

        // Scala implicit parameters (the fit/predict operations) must be passed explicitly from Java.
        knn.fit(trainingSetScala, ParameterMap.Empty(),
                (FitOperation) KNN.fitKNN(TypeInformation.of(DenseVector.class)));

        PredictDataSetOperation predictDataSetOperation = (PredictDataSetOperation) KNN
                .predictValues(ClassTag$.MODULE$.apply(DenseVector.class), TypeInformation.of(DenseVector.class));

        // Each result pairs a test vector (_1) with the array of its nearest training vectors (_2).
        org.apache.flink.api.scala.DataSet<Tuple2<Vector, Vector[]>> predict = knn
                .predict(testingSetScala, ParameterMap.Empty(), predictDataSetOperation);

        predict.map(new MapFunction<Tuple2<Vector, Vector[]>, Tuple2<Vector, Vector[]>>() {
            @Override
            public Tuple2<Vector, Vector[]> map(Tuple2<Vector, Vector[]> value) throws Exception {
                // Printing happens inside the map function; the element is passed through unchanged.
                System.out.print(value._1 + " k近邻为: ");
                for (Vector neighbour : value._2) {
                    System.out.print(neighbour + " ");
                }
                System.out.println();
                return value;
            }
        }, TypeInformation.of(new TypeHint<Tuple2<Vector, Vector[]>>() {
        }), ClassTag$.MODULE$.apply(Tuple2.class))
                // collect() triggers execution of the job; the collected result itself is discarded.
                .collect();
    }

    /**
     * Parses one comma-separated line of numbers into a {@link DenseVector}.
     *
     * <p>Note: despite its name this class performs no scaling or normalization; it only
     * converts the textual features to doubles.
     */
    public static final class Normalizer implements MapFunction<String, DenseVector> {

        /**
         * Converts a record such as {@code "5.8,2.8,5.1,2.4"} into a dense vector.
         *
         * @param value comma-separated decimal numbers
         * @return a vector holding the parsed values in their original order
         * @throws NumberFormatException if any field is not a parsable double
         */
        @Override
        public DenseVector map(String value) throws Exception {
            // A primitive DoubleStream avoids boxing every feature and the intermediate List/Double[] copies.
            double[] features = Stream.of(value.split(","))
                    .mapToDouble(Double::parseDouble)
                    .toArray();
            return new DenseVector(features);
        }
    }
}
运行结果:
DenseVector(4.4, 2.9, 1.4, 0.2) k近邻为: DenseVector(4.8, 3.0, 1.4, 0.3) DenseVector(5.0, 3.4, 1.5, 0.2) DenseVector(5.0, 3.5, 1.3, 0.3)
DenseVector(5.6, 2.8, 4.9, 2.0) k近邻为: DenseVector(5.8, 2.8, 5.1, 2.4) DenseVector(6.0, 2.9, 4.5, 1.5) DenseVector(6.1, 2.9, 4.7, 1.4)
DenseVector(5.6, 3.0, 4.5, 1.5) k近邻为: DenseVector(6.0, 2.9, 4.5, 1.5) DenseVector(6.0, 2.9, 4.5, 1.5) DenseVector(5.5, 2.6, 4.4, 1.2)
DenseVector(4.8, 3.4, 1.9, 0.2) k近邻为: DenseVector(5.0, 3.4, 1.5, 0.2) DenseVector(4.9, 3.6, 1.4, 0.1) DenseVector(4.8, 3.0, 1.4, 0.3)
DenseVector(6.2, 2.8, 4.8, 1.8) k近邻为: DenseVector(6.1, 2.9, 4.7, 1.4) DenseVector(6.5, 2.8, 4.6, 1.5) DenseVector(6.0, 2.9, 4.5, 1.5)
官方文档地址:
https://ci.apache.org/projects/flink/flink-docs-release-1.8/dev/libs/ml/knn.html#k-nearest-neighbors-join