Flink ML Machine Learning Algorithms


1 Classification

1.1 KNN

K-nearest neighbors (KNN) is a classification algorithm. The basic assumption of KNN is that if most of a sample's K nearest neighbors belong to the same label, the sample is very likely to belong to that label as well.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| featuresCol | Vector | "features" | Feature vector. |
| labelCol | Integer | "label" | Label to predict. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| predictionCol | Integer | "prediction" | Predicted label. |

Algorithm Parameters

Below are the parameters required by KnnModel.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| k | 5 | Integer | no | The number of nearest neighbors. |
| featuresCol | "features" | String | no | Features column name. |
| predictionCol | "prediction" | String | no | Prediction column name. |

In addition to the parameters above, Knn requires the parameter below.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| labelCol | "label" | String | no | Label column name. |

Code Example

import org.apache.flink.ml.classification.knn.Knn;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a Knn model and uses it for classification. */
public class KnnExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Vectors.dense(2.0, 3.0), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(200.1, 300.1), 2.0),
                        Row.of(Vectors.dense(200.2, 300.2), 2.0),
                        Row.of(Vectors.dense(200.3, 300.3), 2.0),
                        Row.of(Vectors.dense(200.4, 300.4), 2.0),
                        Row.of(Vectors.dense(200.4, 300.4), 2.0),
                        Row.of(Vectors.dense(200.6, 300.6), 2.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.3, 3.2), 1.0),
                        Row.of(Vectors.dense(2.3, 3.2), 1.0),
                        Row.of(Vectors.dense(2.8, 3.2), 3.0),
                        Row.of(Vectors.dense(300., 3.2), 4.0),
                        Row.of(Vectors.dense(2.2, 3.2), 1.0),
                        Row.of(Vectors.dense(2.4, 3.2), 5.0),
                        Row.of(Vectors.dense(2.5, 3.2), 5.0),
                        Row.of(Vectors.dense(2.5, 3.2), 5.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0));
        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");

        DataStream<Row> predictStream =
                env.fromElements(
                        Row.of(Vectors.dense(4.0, 4.1), 5.0), Row.of(Vectors.dense(300, 42), 2.0));
        Table predictTable = tEnv.fromDataStream(predictStream).as("features", "label");

        // Creates a Knn object and initializes its parameters.
        Knn knn = new Knn().setK(4);

        // Trains the Knn Model.
        KnnModel knnModel = knn.fit(trainTable);

        // Uses the Knn Model for predictions.
        Table outputTable = knnModel.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(knn.getFeaturesCol());
            double expectedResult = (Double) row.getField(knn.getLabelCol());
            double predictionResult = (Double) row.getField(knn.getPredictionCol());
            System.out.printf(
                    "Features: %-15s \tExpected Result: %s \tPrediction Result: %s\n",
                    features, expectedResult, predictionResult);
        }
    }
}

1.2 Linear SVC

Linear Support Vector Machine (Linear SVC) is an algorithm that attempts to find a hyperplane that maximizes the distance between classified samples.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| featuresCol | Vector | "features" | Feature vector. |
| labelCol | Integer | "label" | Label to predict. |
| weightCol | Double | "weight" | Weight of sample. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| predictionCol | Integer | "prediction" | Label of the max probability. |
| rawPredictionCol | Vector | "rawPrediction" | Vector of the probability of each label. |

Algorithm Parameters

Below are the parameters required by LinearSVCModel.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| featuresCol | "features" | String | no | Features column name. |
| predictionCol | "prediction" | String | no | Prediction column name. |
| rawPredictionCol | "rawPrediction" | String | no | Raw prediction column name. |
| threshold | 0.0 | Double | no | Threshold in binary classification prediction applied to rawPrediction. |

In addition to the parameters above, LinearSVC requires the parameters below.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| labelCol | "label" | String | no | Label column name. |
| weightCol | null | String | no | Weight column name. |
| maxIter | 20 | Integer | no | Maximum number of iterations. |
| reg | 0. | Double | no | Regularization parameter. |
| elasticNet | 0. | Double | no | ElasticNet parameter. |
| learningRate | 0.1 | Double | no | Learning rate of optimization method. |
| globalBatchSize | 32 | Integer | no | Global batch size of training algorithms. |
| tol | 1e-6 | Double | no | Convergence tolerance for iterative algorithms. |

Code Example

import org.apache.flink.ml.classification.linearsvc.LinearSVC;
import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a LinearSVC model and uses it for classification. */
public class LinearSVCExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                        Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
                        Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
                        Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
                        Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
                        Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
                        Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
                        Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
                        Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
                        Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");

        // Creates a LinearSVC object and initializes its parameters.
        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");

        // Trains the LinearSVC Model.
        LinearSVCModel linearSVCModel = linearSVC.fit(inputTable);

        // Uses the LinearSVC Model for predictions.
        Table outputTable = linearSVCModel.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(linearSVC.getFeaturesCol());
            double expectedResult = (Double) row.getField(linearSVC.getLabelCol());
            double predictionResult = (Double) row.getField(linearSVC.getPredictionCol());
            DenseVector rawPredictionResult =
                    (DenseVector) row.getField(linearSVC.getRawPredictionCol());
            System.out.printf(
                    "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n",
                    features, expectedResult, predictionResult, rawPredictionResult);
        }
    }
}

1.3 Logistic Regression

Logistic regression is a special case of the Generalized Linear Model. It is widely used to predict a binary response.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| featuresCol | Vector | "features" | Feature vector. |
| labelCol | Integer | "label" | Label to predict. |
| weightCol | Double | "weight" | Weight of sample. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| predictionCol | Integer | "prediction" | Label of the max probability. |
| rawPredictionCol | Vector | "rawPrediction" | Vector of the probability of each label. |

Algorithm Parameters

Below are the parameters required by LogisticRegressionModel.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| featuresCol | "features" | String | no | Features column name. |
| predictionCol | "prediction" | String | no | Prediction column name. |
| rawPredictionCol | "rawPrediction" | String | no | Raw prediction column name. |

In addition to the parameters above, LogisticRegression requires the parameters below.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| labelCol | "label" | String | no | Label column name. |
| weightCol | null | String | no | Weight column name. |
| maxIter | 20 | Integer | no | Maximum number of iterations. |
| reg | 0. | Double | no | Regularization parameter. |
| elasticNet | 0. | Double | no | ElasticNet parameter. |
| learningRate | 0.1 | Double | no | Learning rate of optimization method. |
| globalBatchSize | 32 | Integer | no | Global batch size of training algorithms. |
| tol | 1e-6 | Double | no | Convergence tolerance for iterative algorithms. |
| multiClass | "auto" | String | no | Classification type. Supported values: "auto", "binomial", "multinomial". |

Code Example

import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a LogisticRegression model and uses it for classification. */
public class LogisticRegressionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                        Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
                        Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
                        Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
                        Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
                        Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
                        Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
                        Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
                        Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
                        Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");

        // Creates a LogisticRegression object and initializes its parameters.
        LogisticRegression lr = new LogisticRegression().setWeightCol("weight");

        // Trains the LogisticRegression Model.
        LogisticRegressionModel lrModel = lr.fit(inputTable);

        // Uses the LogisticRegression Model for predictions.
        Table outputTable = lrModel.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(lr.getFeaturesCol());
            double expectedResult = (Double) row.getField(lr.getLabelCol());
            double predictionResult = (Double) row.getField(lr.getPredictionCol());
            DenseVector rawPredictionResult = (DenseVector) row.getField(lr.getRawPredictionCol());
            System.out.printf(
                    "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n",
                    features, expectedResult, predictionResult, rawPredictionResult);
        }
    }
}

1.4 OnlineLogisticRegression

Online Logistic Regression supports training an online logistic regression model on an unbounded stream of training data.

The online optimizer of this algorithm is FTRL-Proximal, proposed by H. Brendan McMahan et al.; see "Ad click prediction: a view from the trenches" for details.
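
For intuition, here is a minimal, self-contained sketch of the per-coordinate FTRL-Proximal update from the McMahan et al. paper. It is illustrative only and not Flink ML's internal implementation; the class, its fields, and the mapping of lambda1/lambda2 onto the reg and elasticNet parameters are assumptions.

import java.util.Arrays;

/** Hypothetical sketch of the FTRL-Proximal update rule (not Flink ML internals). */
public class FtrlProximalSketch {
    private final double alpha = 0.1;   // learning-rate parameter
    private final double beta = 1.0;    // learning-rate smoothing term
    private final double lambda1 = 0.1; // L1 regularization strength
    private final double lambda2 = 0.1; // L2 regularization strength
    private final double[] z;           // per-coordinate gradient accumulator
    private final double[] n;           // per-coordinate sum of squared gradients
    private final double[] w;           // model weights

    public FtrlProximalSketch(int dim) {
        z = new double[dim];
        n = new double[dim];
        w = new double[dim];
    }

    /** Applies one FTRL-Proximal step given the gradient of the loss at w. */
    public void update(double[] gradient) {
        for (int i = 0; i < w.length; i++) {
            double g = gradient[i];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
            if (Math.abs(z[i]) <= lambda1) {
                w[i] = 0.0; // the L1 term keeps small coordinates exactly at zero
            } else {
                w[i] = -(z[i] - Math.signum(z[i]) * lambda1)
                        / ((beta + Math.sqrt(n[i])) / alpha + lambda2);
            }
        }
    }

    public double[] weights() {
        return Arrays.copyOf(w, w.length);
    }
}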

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| featuresCol | Vector | "features" | Feature vector. |
| labelCol | Integer | "label" | Label to predict. |
| weightCol | Double | "weight" | Weight of sample. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| predictionCol | Integer | "prediction" | Label of the max probability. |
| rawPredictionCol | Vector | "rawPrediction" | Vector of the probability of each label. |
| modelVersionCol | Long | "modelVersion" | The version of the model data used for this prediction. |

Algorithm Parameters

Below are the parameters required by OnlineLogisticRegressionModel.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| featuresCol | "features" | String | no | Features column name. |
| predictionCol | "prediction" | String | no | Prediction column name. |
| rawPredictionCol | "rawPrediction" | String | no | Raw prediction column name. |
| modelVersionCol | "modelVersion" | String | no | Model version column name. |

In addition to the parameters above, OnlineLogisticRegression requires the parameters below.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| labelCol | "label" | String | no | Label column name. |
| weightCol | null | String | no | Weight column name. |
| batchStrategy | COUNT_STRATEGY | String | no | Strategy to create mini batch from online train data. |
| globalBatchSize | 32 | Integer | no | Global batch size of training algorithms. |
| reg | 0. | Double | no | Regularization parameter. |
| elasticNet | 0. | Double | no | ElasticNet parameter. |

Code Example

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** Simple program that trains an OnlineLogisticRegression model and uses it for classification. */
public class OnlineLogisticRegressionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data. Both are infinite streams that
        // periodically send out the provided data to trigger model updates and predictions.
        List<Row> trainData1 =
                Arrays.asList(
                        Row.of(Vectors.dense(0.1, 2.), 0.),
                        Row.of(Vectors.dense(0.2, 2.), 0.),
                        Row.of(Vectors.dense(0.3, 2.), 0.),
                        Row.of(Vectors.dense(0.4, 2.), 0.),
                        Row.of(Vectors.dense(0.5, 2.), 0.),
                        Row.of(Vectors.dense(11., 12.), 1.),
                        Row.of(Vectors.dense(12., 11.), 1.),
                        Row.of(Vectors.dense(13., 12.), 1.),
                        Row.of(Vectors.dense(14., 12.), 1.),
                        Row.of(Vectors.dense(15., 12.), 1.));

        List<Row> trainData2 =
                Arrays.asList(
                        Row.of(Vectors.dense(0.2, 3.), 0.),
                        Row.of(Vectors.dense(0.8, 1.), 0.),
                        Row.of(Vectors.dense(0.7, 1.), 0.),
                        Row.of(Vectors.dense(0.6, 2.), 0.),
                        Row.of(Vectors.dense(0.2, 2.), 0.),
                        Row.of(Vectors.dense(14., 17.), 1.),
                        Row.of(Vectors.dense(15., 10.), 1.),
                        Row.of(Vectors.dense(16., 16.), 1.),
                        Row.of(Vectors.dense(17., 10.), 1.),
                        Row.of(Vectors.dense(18., 13.), 1.));

        List<Row> predictData =
                Arrays.asList(
                        Row.of(Vectors.dense(0.8, 2.7), 0.0),
                        Row.of(Vectors.dense(15.5, 11.2), 1.0));

        RowTypeInfo typeInfo =
                new RowTypeInfo(
                        new TypeInformation[] {DenseVectorTypeInfo.INSTANCE, Types.DOUBLE},
                        new String[] {"features", "label"});

        SourceFunction<Row> trainSource =
                new PeriodicSourceFunction(1000, Arrays.asList(trainData1, trainData2));
        DataStream<Row> trainStream = env.addSource(trainSource, typeInfo);
        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");

        SourceFunction<Row> predictSource =
                new PeriodicSourceFunction(1000, Collections.singletonList(predictData));
        DataStream<Row> predictStream = env.addSource(predictSource, typeInfo);
        Table predictTable = tEnv.fromDataStream(predictStream).as("features", "label");

        // Creates an online LogisticRegression object and initializes its parameters and initial
        // model data.
        Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L);
        Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData));
        OnlineLogisticRegression olr =
                new OnlineLogisticRegression()
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setPredictionCol("prediction")
                        .setReg(0.2)
                        .setElasticNet(0.5)
                        .setGlobalBatchSize(10)
                        .setInitialModelData(initModelDataTable);

        // Trains the online LogisticRegression Model.
        OnlineLogisticRegressionModel onlineModel = olr.fit(trainTable);

        // Uses the online LogisticRegression Model for predictions.
        Table outputTable = onlineModel.transform(predictTable)[0];

        // Extracts and displays the results. As the training data stream continuously
        // triggers updates of the internal model data, raw prediction results for the
        // same prediction dataset will change over time.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(olr.getFeaturesCol());
            Double expectedResult = (Double) row.getField(olr.getLabelCol());
            Double predictionResult = (Double) row.getField(olr.getPredictionCol());
            DenseVector rawPredictionResult = (DenseVector) row.getField(olr.getRawPredictionCol());
            System.out.printf(
                    "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n",
                    features, expectedResult, predictionResult, rawPredictionResult);
        }
    }
}

1.5 Naive Bayes

Naive Bayes is a multiclass classifier. Based on Bayes’ theorem, it assumes that there is strong (naive) independence between every pair of features.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| featuresCol | Vector | "features" | Feature vector. |
| labelCol | Integer | "label" | Label to predict. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| predictionCol | Integer | "prediction" | Predicted label. |

Algorithm Parameters

Below are the parameters required by NaiveBayesModel.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| modelType | "multinomial" | String | no | The model type. Supported values: "multinomial". |
| featuresCol | "features" | String | no | Features column name. |
| predictionCol | "prediction" | String | no | Prediction column name. |

In addition to the parameters above, NaiveBayes requires the parameters below.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| labelCol | "label" | String | no | Label column name. |
| smoothing | 1.0 | Double | no | The smoothing parameter. |

Code Example

import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a NaiveBayes model and uses it for classification. */
public class NaiveBayesExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Vectors.dense(0, 0.), 11),
                        Row.of(Vectors.dense(1, 0), 10),
                        Row.of(Vectors.dense(1, 1.), 10));
        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");

        DataStream<Row> predictStream =
                env.fromElements(
                        Row.of(Vectors.dense(0, 1.)),
                        Row.of(Vectors.dense(0, 0.)),
                        Row.of(Vectors.dense(1, 0)),
                        Row.of(Vectors.dense(1, 1.)));
        Table predictTable = tEnv.fromDataStream(predictStream).as("features");

        // Creates a NaiveBayes object and initializes its parameters.
        NaiveBayes naiveBayes =
                new NaiveBayes()
                        .setSmoothing(1.0)
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setPredictionCol("prediction")
                        .setModelType("multinomial");

        // Trains the NaiveBayes Model.
        NaiveBayesModel naiveBayesModel = naiveBayes.fit(trainTable);

        // Uses the NaiveBayes Model for predictions.
        Table outputTable = naiveBayesModel.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(naiveBayes.getFeaturesCol());
            double predictionResult = (Double) row.getField(naiveBayes.getPredictionCol());
            System.out.printf("Features: %s \tPrediction Result: %s\n", features, predictionResult);
        }
    }
}

2 Clustering

2.1 AgglomerativeClustering

AgglomerativeClustering performs a hierarchical clustering using a bottom-up approach. Each observation starts in its own cluster and the clusters are merged together one by one.

The output contains two tables. The first assigns a cluster ID to each data point. The second records which two clusters were merged at each step, in the format (clusterId1, clusterId2, distance, sizeOfMergedCluster); a snippet for reading it appears inside the example below.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| featuresCol | Vector | "features" | Feature vector. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| predictionCol | Integer | "prediction" | Predicted cluster center. |

Algorithm Parameters

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| numClusters | 2 | Integer | no | The max number of clusters to create. |
| distanceThreshold | null | Double | no | Threshold to decide whether two clusters should be merged. |
| linkage | "ward" | String | no | Criterion for computing distance between two clusters. |
| computeFullTree | false | Boolean | no | Whether to compute the full tree after convergence. |
| distanceMeasure | "euclidean" | String | no | Distance measure. |
| featuresCol | "features" | String | no | Features column name. |
| predictionCol | "prediction" | String | no | Prediction column name. |
| windows | GlobalWindows.getInstance() | Windows | no | Windowing strategy that determines how to create mini-batches from input data. |

Code Example

import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that creates an AgglomerativeClustering instance and uses it for clustering. */
public class AgglomerativeClusteringExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		DataStream<DenseVector> inputStream =
			env.fromElements(
				Vectors.dense(1, 1),
				Vectors.dense(1, 4),
				Vectors.dense(1, 0),
				Vectors.dense(4, 1.5),
				Vectors.dense(4, 4),
				Vectors.dense(4, 0));
		Table inputTable = tEnv.fromDataStream(inputStream).as("features");

		// Creates an AgglomerativeClustering object and initializes its parameters.
		AgglomerativeClustering agglomerativeClustering =
			new AgglomerativeClustering()
				.setLinkage(AgglomerativeClusteringParams.LINKAGE_WARD)
				.setDistanceMeasure(EuclideanDistanceMeasure.NAME)
				.setPredictionCol("prediction");

		// Uses the AgglomerativeClustering object for clustering.
		Table[] outputs = agglomerativeClustering.transform(inputTable);

		// Extracts and displays the results.
		for (CloseableIterator<Row> it = outputs[0].execute().collect(); it.hasNext(); ) {
			Row row = it.next();
			DenseVector features =
				(DenseVector) row.getField(agglomerativeClustering.getFeaturesCol());
			int clusterId = (Integer) row.getField(agglomerativeClustering.getPredictionCol());
			System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId);
		}
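
		// The second output table (outputs[1]) records the merge history in the
		// (clusterId1, clusterId2, distance, sizeOfMergedCluster) format described
		// above. The positional reads below are an illustrative addition, not part
		// of the original example; note that collecting it here triggers a second
		// job execution.
		for (CloseableIterator<Row> it = outputs[1].execute().collect(); it.hasNext(); ) {
			Row mergeInfo = it.next();
			System.out.printf(
					"Merged clusters %s and %s (distance: %s, merged size: %s)\n",
					mergeInfo.getField(0),
					mergeInfo.getField(1),
					mergeInfo.getField(2),
					mergeInfo.getField(3));
		}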
	}
}

2.2 K-means

K-means is a commonly-used clustering algorithm. It groups given data points into a predefined number of clusters.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| featuresCol | Vector | "features" | Feature vector. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| predictionCol | Integer | "prediction" | Predicted cluster center. |

Algorithm Parameters

Below are the parameters required by KMeansModel.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| distanceMeasure | "euclidean" | String | no | Distance measure. Supported values: 'euclidean', 'manhattan', 'cosine'. |
| featuresCol | "features" | String | no | Features column name. |
| predictionCol | "prediction" | String | no | Prediction column name. |
| k | 2 | Integer | no | The max number of clusters to create. |

In addition to the parameters above, KMeans requires the parameters below.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| initMode | "random" | String | no | The initialization algorithm. Supported options: 'random'. |
| seed | null | Long | no | The random seed. |
| maxIter | 20 | Integer | no | Maximum number of iterations. |

Code Example

import org.apache.flink.ml.clustering.kmeans.KMeans;
import org.apache.flink.ml.clustering.kmeans.KMeansModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a KMeans model and uses it for clustering. */
public class KMeansExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<DenseVector> inputStream =
                env.fromElements(
                        Vectors.dense(0.0, 0.0),
                        Vectors.dense(0.0, 0.3),
                        Vectors.dense(0.3, 0.0),
                        Vectors.dense(9.0, 0.0),
                        Vectors.dense(9.0, 0.6),
                        Vectors.dense(9.6, 0.0));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features");

        // Creates a K-means object and initializes its parameters.
        KMeans kmeans = new KMeans().setK(2).setSeed(1L);

        // Trains the K-means Model.
        KMeansModel kmeansModel = kmeans.fit(inputTable);

        // Uses the K-means Model for predictions.
        Table outputTable = kmeansModel.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(kmeans.getFeaturesCol());
            int clusterId = (Integer) row.getField(kmeans.getPredictionCol());
            System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId);
        }
    }
}

2.3 Online K-means

Online K-Means extends K-Means to support continuously training a K-Means model on an unbounded stream of training data.

Online K-Means makes updates with the "mini-batch" K-Means rule, generalized to incorporate forgetfulness (i.e., decay). After the centroids estimated on the current batch are acquired, Online K-Means computes the new centroids as the weighted average of the original and the estimated centroids. The weight of the estimated centroids is the number of points assigned to them. The weight of the original centroids is also the number of points, additionally multiplied by the decay factor.

The decay factor scales the contribution of the clusters as estimated thus far. If the decay factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are determined entirely by recent data. Lower values correspond to more forgetting.
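
In code form, the update described above amounts to a decayed weighted average per centroid. The following is a minimal illustrative sketch under those assumptions, not Flink ML's internal code:

/**
 * Sketch of the decayed mini-batch update for one centroid: the old centroid
 * carries weight oldWeight * decayFactor, the batch centroid carries weight
 * batchWeight (the number of points assigned to it in this batch).
 */
static double[] updateCentroid(
        double[] oldCentroid, double oldWeight,
        double[] batchCentroid, double batchWeight, double decayFactor) {
    double discounted = oldWeight * decayFactor; // decay the old evidence
    double total = discounted + batchWeight;
    double[] updated = new double[oldCentroid.length];
    for (int i = 0; i < oldCentroid.length; i++) {
        // decayFactor = 1 weighs all batches equally; decayFactor = 0 keeps
        // only the current batch's centroid.
        updated[i] = (oldCentroid[i] * discounted + batchCentroid[i] * batchWeight) / total;
    }
    return updated;
}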

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| featuresCol | Vector | "features" | Feature vector. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| predictionCol | Integer | "prediction" | Predicted cluster center. |

Algorithm Parameters

Below are the parameters required by OnlineKMeansModel.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| distanceMeasure | "euclidean" | String | no | Distance measure. Supported values: 'euclidean', 'manhattan', 'cosine'. |
| featuresCol | "features" | String | no | Features column name. |
| predictionCol | "prediction" | String | no | Prediction column name. |
| k | 2 | Integer | no | The max number of clusters to create. |

In addition to the parameters above, OnlineKMeans requires the parameters below.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| batchStrategy | COUNT_STRATEGY | String | no | Strategy to create mini batch from online train data. |
| globalBatchSize | 32 | Integer | no | Global batch size of training algorithms. |
| decayFactor | 0. | Double | no | The forgetfulness of the previous centroids. |
| seed | null | Long | no | The random seed. |

Code Example

import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/** Simple program that trains an OnlineKMeans model and uses it for clustering. */
public class OnlineKMeansExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data. Both are infinite streams that
        // periodically send out the provided data to trigger model updates and predictions.
        List<Row> trainData1 =
                Arrays.asList(
                        Row.of(Vectors.dense(0.0, 0.0)),
                        Row.of(Vectors.dense(0.0, 0.3)),
                        Row.of(Vectors.dense(0.3, 0.0)),
                        Row.of(Vectors.dense(9.0, 0.0)),
                        Row.of(Vectors.dense(9.0, 0.6)),
                        Row.of(Vectors.dense(9.6, 0.0)));

        List<Row> trainData2 =
                Arrays.asList(
                        Row.of(Vectors.dense(10.0, 100.0)),
                        Row.of(Vectors.dense(10.0, 100.3)),
                        Row.of(Vectors.dense(10.3, 100.0)),
                        Row.of(Vectors.dense(-10.0, -100.0)),
                        Row.of(Vectors.dense(-10.0, -100.6)),
                        Row.of(Vectors.dense(-10.6, -100.0)));

        List<Row> predictData =
                Arrays.asList(
                        Row.of(Vectors.dense(10.0, 10.0)), Row.of(Vectors.dense(-10.0, 10.0)));

        SourceFunction<Row> trainSource =
                new PeriodicSourceFunction(1000, Arrays.asList(trainData1, trainData2));
        DataStream<Row> trainStream =
                env.addSource(trainSource, new RowTypeInfo(DenseVectorTypeInfo.INSTANCE));
        Table trainTable = tEnv.fromDataStream(trainStream).as("features");

        SourceFunction<Row> predictSource =
                new PeriodicSourceFunction(1000, Collections.singletonList(predictData));
        DataStream<Row> predictStream =
                env.addSource(predictSource, new RowTypeInfo(DenseVectorTypeInfo.INSTANCE));
        Table predictTable = tEnv.fromDataStream(predictStream).as("features");

        // Creates an online K-means object and initializes its parameters and initial model data.
        OnlineKMeans onlineKMeans =
                new OnlineKMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setGlobalBatchSize(6)
                        .setInitialModelData(
                                KMeansModelData.generateRandomModelData(tEnv, 2, 2, 0.0, 0));

        // Trains the online K-means Model.
        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);

        // Uses the online K-means Model for predictions.
        Table outputTable = onlineModel.transform(predictTable)[0];

        // Extracts and displays the results. As the training data stream continuously
        // triggers updates of the internal k-means model data, clustering results for
        // the same prediction dataset will change over time.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row1 = it.next();
            DenseVector features1 = (DenseVector) row1.getField(onlineKMeans.getFeaturesCol());
            Integer clusterId1 = (Integer) row1.getField(onlineKMeans.getPredictionCol());
            Row row2 = it.next();
            DenseVector features2 = (DenseVector) row2.getField(onlineKMeans.getFeaturesCol());
            Integer clusterId2 = (Integer) row2.getField(onlineKMeans.getPredictionCol());
            if (Objects.equals(clusterId1, clusterId2)) {
                System.out.printf("%s and %s are now in the same cluster.\n", features1, features2);
            } else {
                System.out.printf(
                        "%s and %s are now in different clusters.\n", features1, features2);
            }
        }
    }
}

3 Evaluation

3.1 Binary Classification Evaluator

Binary Classification Evaluator calculates the evaluation metrics for binary classification. The input data has rawPrediction, label, and an optional weight column. The rawPrediction can be of type double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities). The output may contain different metrics, as defined by the parameter metricsNames.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| labelCol | Number | "label" | The label of this entry. |
| rawPredictionCol | Vector/Number | "rawPrediction" | The raw prediction result. |
| weightCol | Number | null | The weight of this entry. |

Output Columns

| Name | Type | Description |
| --- | --- | --- |
| "areaUnderROC" | Double | The area under the receiver operating characteristic (ROC) curve. |
| "areaUnderPR" | Double | The area under the precision-recall curve. |
| "areaUnderLorenz" | Double | The area under the Lorenz curve. |
| "ks" | Double | Kolmogorov-Smirnov statistic, which measures the ability of the model to separate positive and negative samples. |

Algorithm Parameters

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| labelCol | "label" | String | no | Label column name. |
| weightCol | null | String | no | Weight column name. |
| rawPredictionCol | "rawPrediction" | String | no | Raw prediction column name. |
| metricsNames | ["areaUnderROC", "areaUnderPR"] | String[] | no | Names of the output metrics. Supported values: 'areaUnderROC', 'areaUnderPR', 'areaUnderLorenz', 'ks'. |

Code Example

import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluator;
import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluatorParams;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

/**
 * Simple program that creates a BinaryClassificationEvaluator instance and uses it for evaluation.
 */
public class BinaryClassificationEvaluatorExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(1.0, Vectors.dense(0.1, 0.9)),
                        Row.of(1.0, Vectors.dense(0.2, 0.8)),
                        Row.of(1.0, Vectors.dense(0.3, 0.7)),
                        Row.of(0.0, Vectors.dense(0.25, 0.75)),
                        Row.of(0.0, Vectors.dense(0.4, 0.6)),
                        Row.of(1.0, Vectors.dense(0.35, 0.65)),
                        Row.of(1.0, Vectors.dense(0.45, 0.55)),
                        Row.of(0.0, Vectors.dense(0.6, 0.4)),
                        Row.of(0.0, Vectors.dense(0.7, 0.3)),
                        Row.of(1.0, Vectors.dense(0.65, 0.35)),
                        Row.of(0.0, Vectors.dense(0.8, 0.2)),
                        Row.of(1.0, Vectors.dense(0.9, 0.1)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("label", "rawPrediction");

        // Creates a BinaryClassificationEvaluator object and initializes its parameters.
        BinaryClassificationEvaluator evaluator =
                new BinaryClassificationEvaluator()
                        .setMetricsNames(
                                BinaryClassificationEvaluatorParams.AREA_UNDER_PR,
                                BinaryClassificationEvaluatorParams.KS,
                                BinaryClassificationEvaluatorParams.AREA_UNDER_ROC);

        // Uses the BinaryClassificationEvaluator object for evaluations.
        Table outputTable = evaluator.transform(inputTable)[0];

        // Extracts and displays the results.
        Row evaluationResult = outputTable.execute().collect().next();
        System.out.printf(
                "Area under the precision-recall curve: %s\n",
                evaluationResult.getField(BinaryClassificationEvaluatorParams.AREA_UNDER_PR));
        System.out.printf(
                "Area under the receiver operating characteristic curve: %s\n",
                evaluationResult.getField(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC));
        System.out.printf(
                "Kolmogorov-Smirnov value: %s\n",
                evaluationResult.getField(BinaryClassificationEvaluatorParams.KS));
    }
}

4 Feature Engineering

4.1 Binarizer

Binarizer binarizes the columns of continuous features by the given thresholds. The continuous features may be a DenseVector, a SparseVector, or a numerical value.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| inputCols | Number/Vector | null | Numbers/Vectors to be binarized. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| outputCols | Number/Vector | null | Binarized Numbers/Vectors. |

Algorithm Parameters

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| inputCols | null | String[] | yes | Input column names. |
| outputCols | null | String[] | yes | Output column names. |
| thresholds | null | Double[] | yes | The thresholds used to binarize continuous features. |

Code Example

import org.apache.flink.ml.feature.binarizer.Binarizer;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that creates a Binarizer instance and uses it for feature engineering. */
public class BinarizerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(
                                1,
                                Vectors.dense(1, 2),
                                Vectors.sparse(
                                        17, new int[] {0, 3, 9}, new double[] {1.0, 2.0, 7.0})),
                        Row.of(
                                2,
                                Vectors.dense(2, 1),
                                Vectors.sparse(
                                        17, new int[] {0, 2, 14}, new double[] {5.0, 4.0, 1.0})),
                        Row.of(
                                3,
                                Vectors.dense(5, 18),
                                Vectors.sparse(
                                        17, new int[] {0, 11, 12}, new double[] {2.0, 4.0, 4.0})));

        Table inputTable = tEnv.fromDataStream(inputStream).as("f0", "f1", "f2");

        // Creates a Binarizer object and initializes its parameters.
        Binarizer binarizer =
                new Binarizer()
                        .setInputCols("f0", "f1", "f2")
                        .setOutputCols("of0", "of1", "of2")
                        .setThresholds(0.0, 0.0, 0.0);

        // Transforms input data.
        Table outputTable = binarizer.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            Object[] inputValues = new Object[binarizer.getInputCols().length];
            Object[] outputValues = new Object[binarizer.getInputCols().length];
            for (int i = 0; i < inputValues.length; i++) {
                inputValues[i] = row.getField(binarizer.getInputCols()[i]);
                outputValues[i] = row.getField(binarizer.getOutputCols()[i]);
            }

            System.out.printf(
                    "Input Values: %s\tOutput Values: %s\n",
                    Arrays.toString(inputValues), Arrays.toString(outputValues));
        }
    }
}

4.2 Bucketizer

Bucketizer is an algorithm that maps multiple columns of continuous features to multiple columns of discrete features, i.e., bucket indices. The indices are in [0, numSplitsInThisColumn - 1].
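
Conceptually, each value is assigned the index of the split interval it falls into. A rough sketch of the per-column lookup, under the simplifying assumption of strictly ascending splits and half-open intervals, might look like this:

// Hypothetical illustration of the bucket lookup; Flink ML performs the
// equivalent mapping internally, with boundary and invalid-value handling
// governed by the handleInvalid parameter.
static int bucketIndex(double value, Double[] splits) {
    for (int i = 0; i < splits.length - 1; i++) {
        // A value in [splits[i], splits[i + 1]) falls into bucket i.
        if (value >= splits[i] && value < splits[i + 1]) {
            return i;
        }
    }
    return -1; // Out-of-range values are handled per the handleInvalid strategy.
}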

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| inputCols | Number | null | Continuous features to be bucketized. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| outputCols | Double | null | Discretized features. |

Algorithm Parameters

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| inputCols | null | String[] | yes | Input column names. |
| outputCols | null | String[] | yes | Output column names. |
| handleInvalid | "error" | String | no | Strategy to handle invalid entries. Supported values: 'error', 'skip', 'keep'. |
| splitsArray | null | Double[][] | yes | Array of split points for mapping continuous features into buckets. |

Code Example

import org.apache.flink.ml.common.param.HasHandleInvalid;
import org.apache.flink.ml.feature.bucketizer.Bucketizer;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that creates a Bucketizer instance and uses it for feature engineering. */
public class BucketizerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream = env.fromElements(Row.of(-0.5, 0.0, 1.0, 0.0));
        Table inputTable = tEnv.fromDataStream(inputStream).as("f1", "f2", "f3", "f4");

        // Creates a Bucketizer object and initializes its parameters.
        Double[][] splitsArray =
                new Double[][] {
                    new Double[] {-0.5, 0.0, 0.5},
                    new Double[] {-1.0, 0.0, 2.0},
                    new Double[] {Double.NEGATIVE_INFINITY, 10.0, Double.POSITIVE_INFINITY},
                    new Double[] {Double.NEGATIVE_INFINITY, 1.5, Double.POSITIVE_INFINITY}
                };
        Bucketizer bucketizer =
                new Bucketizer()
                        .setInputCols("f1", "f2", "f3", "f4")
                        .setOutputCols("o1", "o2", "o3", "o4")
                        .setSplitsArray(splitsArray)
                        .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);

        // Uses the Bucketizer object for feature transformations.
        Table outputTable = bucketizer.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            double[] inputValues = new double[bucketizer.getInputCols().length];
            double[] outputValues = new double[bucketizer.getInputCols().length];
            for (int i = 0; i < inputValues.length; i++) {
                inputValues[i] = (double) row.getField(bucketizer.getInputCols()[i]);
                outputValues[i] = (double) row.getField(bucketizer.getOutputCols()[i]);
            }

            System.out.printf(
                    "Input Values: %s\tOutput Values: %s\n",
                    Arrays.toString(inputValues), Arrays.toString(outputValues));
        }
    }
}

4.3 CountVectorizer

CountVectorizer is an algorithm that converts a collection of text documents to vectors of token counts. When an a-priori dictionary is not available, CountVectorizer can be used as an estimator to extract the vocabulary and generate a CountVectorizerModel. The model produces sparse representations of the documents over the vocabulary, which can then be passed to other algorithms like LDA.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| inputCol | String[] | "input" | Input string array. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| outputCol | SparseVector | "output" | Vector of token counts. |

Algorithm Parameters

Below are the parameters required by CountVectorizerModel.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| inputCol | "input" | String | no | Input column name. |
| outputCol | "output" | String | no | Output column name. |
| minTF | 1.0 | Double | no | Filter to ignore rare words in a document. For each document, terms with a frequency/count less than the given threshold are ignored. If this is an integer >= 1, it specifies a count (how many times the term must appear in the document); if this is a double in [0,1), it specifies a fraction (out of the document's token count). |
| binary | false | Boolean | no | Binary toggle to control the output vector values. If true, all nonzero counts (after the minTF filter is applied) are set to 1.0. |

In addition to the parameters above, CountVectorizer requires the parameters below.

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| vocabularySize | 2^18 | Integer | no | Max size of the vocabulary. CountVectorizer will build a vocabulary that only considers the top vocabularySize terms ordered by term frequency across the corpus. |
| minDF | 1.0 | Double | no | Specifies the minimum number of different documents a term must appear in to be included in the vocabulary. If this is an integer >= 1, it specifies the number of documents the term must appear in; if this is a double in [0,1), it specifies the fraction of documents. |
| maxDF | 2^63 - 1 | Double | no | Specifies the maximum number of different documents a term can appear in to be included in the vocabulary. A term that appears in more documents than the threshold is ignored. If this is an integer >= 1, it specifies the maximum number of documents the term can appear in; if this is a double in [0,1), it specifies the maximum fraction of documents the term can appear in. |

Code Example

import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/**
 * Simple program that trains a {@link CountVectorizer} model and uses it for feature engineering.
 */
public class CountVectorizerExample {

    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> dataStream =
                env.fromElements(
                        Row.of((Object) new String[] {"a", "c", "b", "c"}),
                        Row.of((Object) new String[] {"c", "d", "e"}),
                        Row.of((Object) new String[] {"a", "b", "c"}),
                        Row.of((Object) new String[] {"e", "f"}),
                        Row.of((Object) new String[] {"a", "c", "a"}));
        Table inputTable = tEnv.fromDataStream(dataStream).as("input");

        // Creates a CountVectorizer object and initializes its parameters.
        CountVectorizer countVectorizer = new CountVectorizer();

        // Trains the CountVectorizer model
        CountVectorizerModel model = countVectorizer.fit(inputTable);

        // Uses the CountVectorizer model for predictions.
        Table outputTable = model.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            String[] inputValue = (String[]) row.getField(countVectorizer.getInputCol());
            SparseVector outputValue = (SparseVector) row.getField(countVectorizer.getOutputCol());
            System.out.printf(
                    "Input Value: %-15s \tOutput Value: %s\n",
                    Arrays.toString(inputValue), outputValue.toString());
        }
    }
}

4.4 DCT

DCT is a Transformer that takes the 1D discrete cosine transform of a real vector. No zero padding is performed on the input vector. It returns a real vector of the same length representing the DCT. The return vector is scaled such that the transform matrix is unitary (aka scaled DCT-II).
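
For reference, here is a naive implementation of the unitary (scaled) DCT-II described above. It is a sketch for intuition only; in practice the DCT transformer should be used.

// Naive scaled DCT-II: y[k] = s(k) * sum_i x[i] * cos(pi/N * (i + 0.5) * k),
// where s(0) = sqrt(1/N) and s(k) = sqrt(2/N) for k > 0, which is the scaling
// that makes the transform matrix unitary.
static double[] scaledDctII(double[] x) {
    int n = x.length;
    double[] y = new double[n];
    for (int k = 0; k < n; k++) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += x[i] * Math.cos(Math.PI / n * (i + 0.5) * k);
        }
        y[k] = (k == 0 ? Math.sqrt(1.0 / n) : Math.sqrt(2.0 / n)) * sum;
    }
    return y;
}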

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| inputCol | Vector | "input" | Input vector to be cosine transformed. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| outputCol | Vector | "output" | Cosine transformed output vector. |

Algorithm Parameters

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| inputCol | "input" | String | no | Input column name. |
| outputCol | "output" | String | no | Output column name. |
| inverse | false | Boolean | no | Whether to perform the inverse DCT (true) or forward DCT (false). |

Code Example

import org.apache.flink.ml.feature.dct.DCT;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.List;

/** Simple program that creates a DCT instance and uses it for feature engineering. */
public class DCTExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        List<Vector> inputData =
                Arrays.asList(
                        Vectors.dense(1.0, 1.0, 1.0, 1.0), Vectors.dense(1.0, 0.0, -1.0, 0.0));
        Table inputTable = tEnv.fromDataStream(env.fromCollection(inputData)).as("input");

        // Creates a DCT object and initializes its parameters.
        DCT dct = new DCT();

        // Uses the DCT object for feature transformations.
        Table outputTable = dct.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            Vector inputValue = row.getFieldAs(dct.getInputCol());
            Vector outputValue = row.getFieldAs(dct.getOutputCol());

            System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}

4.5 ElementwiseProduct

ElementwiseProduct multiplies each input vector with a given scaling vector using the Hadamard product. If the size of the input vector does not equal the size of the scaling vector, the transformer throws an IllegalArgumentException.

Input Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| inputCol | Vector | "input" | Features to be scaled. |

Output Columns

| Param name | Type | Default | Description |
| --- | --- | --- | --- |
| outputCol | Vector | "output" | Scaled features. |

Algorithm Parameters

| Param | Default | Type | Required | Description |
| --- | --- | --- | --- | --- |
| inputCol | "input" | String | no | Input column name. |
| outputCol | "output" | String | no | Output column name. |
| scalingVec | null | Vector | yes | The scaling vector. |

Code Example

import org.apache.flink.ml.feature.elementwiseproduct.ElementwiseProduct;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/**
 * Simple program that creates an ElementwiseProduct instance and uses it for feature engineering.
 */
public class ElementwiseProductExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(0, Vectors.dense(1.1, 3.2)), Row.of(1, Vectors.dense(2.1, 3.1)));

        Table inputTable = tEnv.fromDataStream(inputStream).as("id", "vec");

        // Creates an ElementwiseProduct object and initializes its parameters.
        ElementwiseProduct elementwiseProduct =
                new ElementwiseProduct()
                        .setInputCol("vec")
                        .setOutputCol("outputVec")
                        .setScalingVec(Vectors.dense(1.1, 1.1));

        // Transforms input data.
        Table outputTable = elementwiseProduct.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            Vector inputValue = (Vector) row.getField(elementwiseProduct.getInputCol());
            Vector outputValue = (Vector) row.getField(elementwiseProduct.getOutputCol());
            System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}

4.6 FeatureHasher

FeatureHasher transforms a set of categorical or numerical features into a sparse vector of a specified dimension. The rules of hashing categorical columns and numerical columns are as follows:

  • For numerical columns, the index of this feature in the output vector is the hash value of the column name and its corresponding value is the same as the input.
  • For categorical columns, the index of this feature in the output vector is the hash value of the string “column_name=value” and the corresponding value is 1.0.

If multiple features are projected into the same column, the output values are accumulated. For the hashing trick, see https://en.wikipedia.org/wiki/Feature_hashing for details.
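
To make the two hashing rules concrete, here is a minimal sketch of the idea. The hash function below is a stand-in (Flink ML uses its own internal hash), so the actual indices produced by FeatureHasher will differ:

/** Illustrative sketch of the hashing rules above; not FeatureHasher's internal code. */
public class HashingTrickSketch {
    public static void main(String[] args) {
        int numFeatures = 1000;
        double[] output = new double[numFeatures];

        // Numerical column "f1" with value 1.0:
        // index = hash(column name), value = the input value itself.
        output[Math.floorMod("f1".hashCode(), numFeatures)] += 1.0;

        // Categorical column "f0" with value "a":
        // index = hash("column_name=value"), value = 1.0.
        output[Math.floorMod("f0=a".hashCode(), numFeatures)] += 1.0;

        // If two features hash to the same index, their values accumulate.
    }
}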

输入列

参数名称 | 类型 | 默认值 | 描述
inputCols | Number/String/Boolean | null | Columns to be hashed.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | Vector | “output” | Output vector.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCols | null | String[] | yes | Input column names.
outputCol | “output” | String | no | Output column name.
categoricalCols | [] | String[] | no | Categorical column names.
numFeatures | 262144 | Integer | no | The number of features.

代码示例

import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that creates a FeatureHasher instance and uses it for feature engineering. */
public class FeatureHasherExample {
    public static void main(String[] args) {

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> dataStream =
                env.fromCollection(
                        Arrays.asList(Row.of(0, "a", 1.0, true), Row.of(1, "c", 1.0, false)));
        Table inputDataTable = tEnv.fromDataStream(dataStream).as("id", "f0", "f1", "f2");

        // Creates a FeatureHasher object and initializes its parameters.
        FeatureHasher featureHash =
                new FeatureHasher()
                        .setInputCols("f0", "f1", "f2")
                        .setCategoricalCols("f0", "f2")
                        .setOutputCol("vec")
                        .setNumFeatures(1000);

        // Uses the FeatureHasher object for feature transformations.
        Table outputTable = featureHash.transform(inputDataTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            Object[] inputValues = new Object[featureHash.getInputCols().length];
            for (int i = 0; i < inputValues.length; i++) {
                inputValues[i] = row.getField(featureHash.getInputCols()[i]);
            }
            Vector outputValue = (Vector) row.getField(featureHash.getOutputCol());

            System.out.printf(
                    "Input Values: %s \tOutput Value: %s\n",
                    Arrays.toString(inputValues), outputValue);
        }
    }
}

4.7 HashingTF

HashingTF maps a sequence of terms (strings, numbers, booleans) to a sparse vector with a specified dimension using the hashing trick. If multiple features are projected into the same column, the output values are accumulated by default.

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | List/Array of primitive data types or strings | “input” | Input sequence of terms.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | SparseVector | “output” | Output sparse vector.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
binary | false | Boolean | no | Whether each dimension of the output vector is binary or not.
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.
numFeatures | 262144 | Integer | no | The number of features. It will be the length of the output vector.

代码示例

import org.apache.flink.ml.feature.hashingtf.HashingTF;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.List;

/** Simple program that creates a HashingTF instance and uses it for feature engineering. */
public class HashingTFExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		DataStream<Row> inputStream =
			env.fromElements(
				Row.of(
					Arrays.asList(
						"HashingTFTest", "Hashing", "Term", "Frequency", "Test")),
				Row.of(
					Arrays.asList(
						"HashingTFTest", "Hashing", "Hashing", "Test", "Test")));

		Table inputTable = tEnv.fromDataStream(inputStream).as("input");

		// Creates a HashingTF object and initializes its parameters.
		HashingTF hashingTF =
			new HashingTF().setInputCol("input").setOutputCol("output").setNumFeatures(128);

		// Uses the HashingTF object for feature transformations.
		Table outputTable = hashingTF.transform(inputTable)[0];

		// Extracts and displays the results.
		for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
			Row row = it.next();

			List<Object> inputValue = (List<Object>) row.getField(hashingTF.getInputCol());
			SparseVector outputValue = (SparseVector) row.getField(hashingTF.getOutputCol());

			System.out.printf(
				"Input Value: %s \tOutput Value: %s\n",
				Arrays.toString(inputValue.stream().toArray()), outputValue);
		}
	}
}

4.8 IDF

IDF computes the inverse document frequency (IDF) for the input documents. IDF is computed as idf = log((m + 1) / (d(t) + 1)), where m is the total number of documents and d(t) is the number of documents that contain term t.

IDFModel further uses the computed inverse document frequency to compute tf-idf.
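
For instance, with m = 3 documents, a term that appears in exactly one document gets $\mathrm{idf} = \log\frac{3+1}{1+1} = \log 2 \approx 0.693$, while a term that appears in all three documents gets $\log\frac{4}{4} = 0$.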

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | Vector | “input” | Input documents.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | Vector | “output” | Tf-idf values of the input.

算法参数

Below are the parameters required by IDFModel.

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.

IDF needs parameters above and also below.

参数 | 默认值 | 类型 | 是否必须 | 描述
minDocFreq | 0 | Integer | no | Minimum number of documents that a term should appear in for filtering.

代码示例

import org.apache.flink.ml.feature.idf.IDF;
import org.apache.flink.ml.feature.idf.IDFModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains an IDF model and uses it for feature engineering. */
public class IDFExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		DataStream<Row> inputStream =
			env.fromElements(
				Row.of(Vectors.dense(0, 1, 0, 2)),
				Row.of(Vectors.dense(0, 1, 2, 3)),
				Row.of(Vectors.dense(0, 1, 0, 0)));

		Table inputTable = tEnv.fromDataStream(inputStream).as("input");

		// Creates an IDF object and initializes its parameters.
		IDF idf = new IDF().setMinDocFreq(2);

		// Trains the IDF Model.
		IDFModel model = idf.fit(inputTable);

		// Uses the IDF Model for predictions.
		Table outputTable = model.transform(inputTable)[0];

		// Extracts and displays the results.
		for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
			Row row = it.next();
			DenseVector inputValue = (DenseVector) row.getField(idf.getInputCol());
			DenseVector outputValue = (DenseVector) row.getField(idf.getOutputCol());
			System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
		}
	}
}

4.9 Imputer

Imputer is an Estimator that completes missing values in the input columns.

Missing values can be imputed using the statistics (mean, median or most frequent) of each column in which the missing values are located. The input columns should be of numeric type.

Note The mean/median/most frequent value is computed after filtering out missing values. Null values are always treated as missing, and so are also imputed.

Note The parameter relativeError is only effective when the strategy is median.

输入列

参数名称 | 类型 | 默认值 | 描述
inputCols | Number | null | Features to be imputed.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCols | Double | null | Imputed features.

算法参数

Below are the parameters required by ImputerModel.

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCols | null | String[] | yes | Input column names.
outputCols | null | String[] | yes | Output column names.
missingValue | Double.NaN | Double | no | The placeholder for the missing values. All occurrences of missing values will be imputed.

Imputer needs parameters above and also below.

参数 | 默认值 | 类型 | 是否必须 | 描述
strategy | “mean” | String | no | The imputation strategy. Supported values: ‘mean’, ‘median’, ‘most_frequent’.
relativeError | 0.001 | Double | no | The relative target precision for the approximate quantile algorithm.

代码示例

import org.apache.flink.ml.feature.imputer.Imputer;
import org.apache.flink.ml.feature.imputer.ImputerModel;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that trains an {@link Imputer} model and uses it for feature engineering. */
public class ImputerExample {

    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Double.NaN, 9.0),
                        Row.of(1.0, 9.0),
                        Row.of(1.5, 9.0),
                        Row.of(2.5, Double.NaN),
                        Row.of(5.0, 5.0),
                        Row.of(5.0, 4.0));
        Table trainTable = tEnv.fromDataStream(trainStream).as("input1", "input2");

        // Creates an Imputer object and initializes its parameters.
        Imputer imputer =
                new Imputer()
                        .setInputCols("input1", "input2")
                        .setOutputCols("output1", "output2")
                        .setStrategy("mean")
                        .setMissingValue(Double.NaN);

        // Trains the Imputer model.
        ImputerModel model = imputer.fit(trainTable);

        // Uses the Imputer model for predictions.
        Table outputTable = model.transform(trainTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            double[] inputValues = new double[imputer.getInputCols().length];
            double[] outputValues = new double[imputer.getInputCols().length];
            for (int i = 0; i < inputValues.length; i++) {
                inputValues[i] = (double) row.getField(imputer.getInputCols()[i]);
                outputValues[i] = (double) row.getField(imputer.getOutputCols()[i]);
            }
            System.out.printf(
                    "Input Values: %s\tOutput Values: %s\n",
                    Arrays.toString(inputValues), Arrays.toString(outputValues));
        }
    }
}
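
With the ‘mean’ strategy above, the non-missing values of input1 are {1.0, 1.5, 2.5, 5.0, 5.0}, whose mean is 3.0, so the Double.NaN in input1 is imputed as 3.0; likewise, the Double.NaN in input2 is replaced by the mean of {9.0, 9.0, 9.0, 5.0, 4.0}, which is 7.2.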

4.10 IndexToString

IndexToStringModel transforms input index column(s) to string column(s) using the model data computed by StringIndexer. It is a reverse operation of StringIndexerModel.

输入列

参数名称 | 类型 | 默认值 | 描述
inputCols | Integer | null | Indices to be transformed to string.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCols | String | null | Transformed strings.

算法参数

Below are the parameters required by IndexToStringModel.

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCols | null | String[] | yes | Input column names.
outputCols | null | String[] | yes | Output column names.

代码示例

import org.apache.flink.ml.feature.stringindexer.IndexToStringModel;
import org.apache.flink.ml.feature.stringindexer.StringIndexerModelData;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/**
 * Simple program that creates an IndexToStringModelExample instance and uses it for feature
 * engineering.
 */
public class IndexToStringModelExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Creates model data for IndexToStringModel.
        StringIndexerModelData modelData =
                new StringIndexerModelData(
                        new String[][] {{"a", "b", "c", "d"}, {"-1.0", "0.0", "1.0", "2.0"}});
        Table modelTable = tEnv.fromDataStream(env.fromElements(modelData)).as("stringArrays");

        // Generates input data.
        DataStream<Row> predictStream = env.fromElements(Row.of(0, 3), Row.of(1, 2));
        Table predictTable = tEnv.fromDataStream(predictStream).as("inputCol1", "inputCol2");

        // Creates an indexToStringModel object and initializes its parameters.
        IndexToStringModel indexToStringModel =
                new IndexToStringModel()
                        .setInputCols("inputCol1", "inputCol2")
                        .setOutputCols("outputCol1", "outputCol2")
                        .setModelData(modelTable);

        // Uses the indexToStringModel object for feature transformations.
        Table outputTable = indexToStringModel.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            int[] inputValues = new int[indexToStringModel.getInputCols().length];
            String[] outputValues = new String[indexToStringModel.getInputCols().length];
            for (int i = 0; i < inputValues.length; i++) {
                inputValues[i] = (int) row.getField(indexToStringModel.getInputCols()[i]);
                outputValues[i] = (String) row.getField(indexToStringModel.getOutputCols()[i]);
            }

            System.out.printf(
                    "Input Values: %s \tOutput Values: %s\n",
                    Arrays.toString(inputValues), Arrays.toString(outputValues));
        }
    }
}

4.11 Interaction

Interaction takes vector or numerical columns, and generates a single vector column that contains the product of all combinations of one value from each input column.

For example, when the input feature values are Double(2) and Vector(3, 4), the output would be Vector(6, 8). When the input feature values are Vector(1, 2) and Vector(3, 4), the output would be Vector(3, 4, 6, 8). If you swap the positions of these two input vectors, the output would be Vector(3, 6, 4, 8).
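
The ordering of the output follows a nested product over the inputs; here is a minimal sketch of that ordering (not Flink ML's internal implementation):

/** Sketch of the interaction ordering for Vector(1, 2) and Vector(3, 4). */
public class InteractionOrderSketch {
    public static void main(String[] args) {
        double[] a = {1, 2}; // first input vector
        double[] b = {3, 4}; // second input vector
        double[] out = new double[a.length * b.length];
        int k = 0;
        for (double ai : a) {
            for (double bj : b) {
                out[k++] = ai * bj;
            }
        }
        // out == [3.0, 4.0, 6.0, 8.0], matching the example above.
    }
}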

输入列

参数名称 | 类型 | 默认值 | 描述
inputCols | Vector | null | Columns to be interacted.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | Vector | “output” | Interacted vector.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCols | null | String[] | yes | Input column names.
outputCol | “output” | String | no | Output column name.

代码示例

import org.apache.flink.ml.feature.interaction.Interaction;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that creates an Interaction instance and uses it for feature engineering. */
public class InteractionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(0, Vectors.dense(1.1, 3.2), Vectors.dense(2, 3)),
                        Row.of(1, Vectors.dense(2.1, 3.1), Vectors.dense(1, 3)));

        Table inputTable = tEnv.fromDataStream(inputStream).as("f0", "f1", "f2");

        // Creates an Interaction object and initializes its parameters.
        Interaction interaction =
                new Interaction().setInputCols("f0", "f1", "f2").setOutputCol("outputVec");

        // Transforms input data.
        Table outputTable = interaction.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            Object[] inputValues = new Object[interaction.getInputCols().length];
            for (int i = 0; i < inputValues.length; i++) {
                inputValues[i] = row.getField(interaction.getInputCols()[i]);
            }
            Vector outputValue = (Vector) row.getField(interaction.getOutputCol());
            System.out.printf(
                    "Input Values: %s \tOutput Value: %s\n",
                    Arrays.toString(inputValues), outputValue);
        }
    }
}

4.12 KBinsDiscretizer

KBinsDiscretizer is an algorithm that implements discretization (also known as quantization or binning) to transform continuous features into discrete ones. The output values are in [0, numBins).

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | DenseVector | “input” | Vectors to be discretized.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | DenseVector | “output” | Discretized vectors.

算法参数

Below are the parameters required by KBinsDiscretizerModel.

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.

KBinsDiscretizer needs parameters above and also below.

参数 | 默认值 | 类型 | 是否必须 | 描述
strategy | “quantile” | String | no | Strategy used to define the width of the bin. Supported values: ‘uniform’, ‘quantile’, ‘kmeans’.
numBins | 5 | Integer | no | Number of bins to produce.
subSamples | 200000 | Integer | no | Maximum number of samples used to fit the model.

代码示例

import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a KBinsDiscretizer model and uses it for feature engineering. */
public class KBinsDiscretizerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(1, 10, 0)),
                        Row.of(Vectors.dense(1, 10, 0)),
                        Row.of(Vectors.dense(1, 10, 0)),
                        Row.of(Vectors.dense(4, 10, 0)),
                        Row.of(Vectors.dense(5, 10, 0)),
                        Row.of(Vectors.dense(6, 10, 0)),
                        Row.of(Vectors.dense(7, 10, 0)),
                        Row.of(Vectors.dense(10, 10, 0)),
                        Row.of(Vectors.dense(13, 10, 3)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("input");

        // Creates a KBinsDiscretizer object and initializes its parameters.
        KBinsDiscretizer kBinsDiscretizer =
                new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);

        // Trains the KBinsDiscretizer Model.
        KBinsDiscretizerModel model = kBinsDiscretizer.fit(inputTable);

        // Uses the KBinsDiscretizer Model for predictions.
        Table outputTable = model.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector inputValue = (DenseVector) row.getField(kBinsDiscretizer.getInputCol());
            DenseVector outputValue = (DenseVector) row.getField(kBinsDiscretizer.getOutputCol());
            System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}
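
With strategy UNIFORM and numBins = 3, the bin edges of each dimension are spaced evenly between its observed minimum and maximum; assuming that standard scheme, the first dimension above (values from 1 to 13) gets edges {1, 5, 9, 13}, so 1 and 4 map to bin 0, 5 through 7 map to bin 1, and 10 and 13 map to bin 2.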

4.13 MaxAbsScaler

MaxAbsScaler is an algorithm that rescales feature values to the range [-1, 1] by dividing each feature by its maximum absolute value. It does not shift/center the data and thus does not destroy any sparsity.
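
Concretely, each feature value is mapped as x' = x / max|x|, where the maximum absolute value is computed per feature over the training data; in the example below, the first feature's maximum absolute value is 200, so a prediction input of 150.0 becomes 0.75.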

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | Vector | “input” | Features to be scaled.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | Vector | “output” | Scaled features.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.

代码示例

import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler;
import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a MaxAbsScaler model and uses it for feature engineering. */
public class MaxAbsScalerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Vectors.dense(0.0, 3.0)),
                        Row.of(Vectors.dense(2.1, 0.0)),
                        Row.of(Vectors.dense(4.1, 5.1)),
                        Row.of(Vectors.dense(6.1, 8.1)),
                        Row.of(Vectors.dense(200, 400)));
        Table trainTable = tEnv.fromDataStream(trainStream).as("input");

        DataStream<Row> predictStream =
                env.fromElements(
                        Row.of(Vectors.dense(150.0, 90.0)),
                        Row.of(Vectors.dense(50.0, 40.0)),
                        Row.of(Vectors.dense(100.0, 50.0)));
        Table predictTable = tEnv.fromDataStream(predictStream).as("input");

        // Creates a MaxAbsScaler object and initializes its parameters.
        MaxAbsScaler maxAbsScaler = new MaxAbsScaler();

        // Trains the MaxAbsScaler Model.
        MaxAbsScalerModel maxAbsScalerModel = maxAbsScaler.fit(trainTable);

        // Uses the MaxAbsScaler Model for predictions.
        Table outputTable = maxAbsScalerModel.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector inputValue = (DenseVector) row.getField(maxAbsScaler.getInputCol());
            DenseVector outputValue = (DenseVector) row.getField(maxAbsScaler.getOutputCol());
            System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}

4.14 MinHashLSH

MinHashLSH is a Locality Sensitive Hashing (LSH) scheme for the Jaccard distance metric. The input features are sets of natural numbers represented as non-zero indices of vectors, either dense vectors or sparse vectors. Typically, sparse vectors are more efficient.

In addition to transforming input feature vectors to multiple hash values, the MinHashLSH model also supports approximate nearest neighbors search within a dataset regarding a key vector and approximate similarity join between two datasets.
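
For reference, the Jaccard distance between two sets $A$ and $B$ (here, the sets of non-zero indices of two vectors) is $d(A, B) = 1 - \frac{|A \cap B|}{|A \cup B|}$.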

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | Vector | “input” | Features to be mapped.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | DenseVector[] | “output” | Hash values.

算法参数

Below are the parameters required by MinHashLSHModel.

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.

MinHashLSH needs parameters above and also below.

参数 | 默认值 | 类型 | 是否必须 | 描述
seed | null | Long | no | The random seed.
numHashTables | 1 | Integer | no | Default number of hash tables, for OR-amplification.
numHashFunctionsPerTable | 1 | Integer | no | Default number of hash functions per table, for AND-amplification.

代码示例

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.ml.feature.lsh.MinHashLSH;
import org.apache.flink.ml.feature.lsh.MinHashLSHModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

import org.apache.commons.collections.IteratorUtils;

import java.util.Arrays;
import java.util.List;

import static org.apache.flink.table.api.Expressions.$;

/**
 * Simple program that trains a MinHashLSH model and uses it for approximate nearest neighbors and
 * similarity join.
 */
public class MinHashLSHExample {
    public static void main(String[] args) throws Exception {

        // Creates a new StreamExecutionEnvironment.
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        // Creates a StreamTableEnvironment.
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates two datasets.
        Table dataA =
                tEnv.fromDataStream(
                        env.fromCollection(
                                Arrays.asList(
                                        Row.of(
                                                0,
                                                Vectors.sparse(
                                                        6,
                                                        new int[] {0, 1, 2},
                                                        new double[] {1., 1., 1.})),
                                        Row.of(
                                                1,
                                                Vectors.sparse(
                                                        6,
                                                        new int[] {2, 3, 4},
                                                        new double[] {1., 1., 1.})),
                                        Row.of(
                                                2,
                                                Vectors.sparse(
                                                        6,
                                                        new int[] {0, 2, 4},
                                                        new double[] {1., 1., 1.}))),
                                Types.ROW_NAMED(
                                        new String[] {"id", "vec"},
                                        Types.INT,
                                        TypeInformation.of(SparseVector.class))));

        Table dataB =
                tEnv.fromDataStream(
                        env.fromCollection(
                                Arrays.asList(
                                        Row.of(
                                                3,
                                                Vectors.sparse(
                                                        6,
                                                        new int[] {1, 3, 5},
                                                        new double[] {1., 1., 1.})),
                                        Row.of(
                                                4,
                                                Vectors.sparse(
                                                        6,
                                                        new int[] {2, 3, 5},
                                                        new double[] {1., 1., 1.})),
                                        Row.of(
                                                5,
                                                Vectors.sparse(
                                                        6,
                                                        new int[] {1, 2, 4},
                                                        new double[] {1., 1., 1.}))),
                                Types.ROW_NAMED(
                                        new String[] {"id", "vec"},
                                        Types.INT,
                                        TypeInformation.of(SparseVector.class))));

        // Creates a MinHashLSH estimator object and initializes its parameters.
        MinHashLSH lsh =
                new MinHashLSH()
                        .setInputCol("vec")
                        .setOutputCol("hashes")
                        .setSeed(2022)
                        .setNumHashTables(5);

        // Trains the MinHashLSH model.
        MinHashLSHModel model = lsh.fit(dataA);

        // Uses the MinHashLSH model for transformation.
        Table output = model.transform(dataA)[0];

        // Extracts and displays the results.
        List<String> fieldNames = output.getResolvedSchema().getColumnNames();
        for (Row result :
                (List<Row>) IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect())) {
            Vector inputValue = result.getFieldAs(fieldNames.indexOf(lsh.getInputCol()));
            DenseVector[] outputValue = result.getFieldAs(fieldNames.indexOf(lsh.getOutputCol()));
            System.out.printf(
                    "Vector: %s \tHash values: %s\n", inputValue, Arrays.toString(outputValue));
        }

        // Finds approximate nearest neighbors of the key.
        Vector key = Vectors.sparse(6, new int[] {1, 3}, new double[] {1., 1.});
        output = model.approxNearestNeighbors(dataA, key, 2).select($("id"), $("distCol"));
        for (Row result :
                (List<Row>) IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect())) {
            int idValue = result.getFieldAs(fieldNames.indexOf("id"));
            double distValue = result.getFieldAs(result.getArity() - 1);
            System.out.printf("ID: %d \tDistance: %f\n", idValue, distValue);
        }

        // Approximately finds pairs from two datasets with distances smaller than the threshold.
        output = model.approxSimilarityJoin(dataA, dataB, .6, "id");
        for (Row result :
                (List<Row>) IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect())) {
            int idAValue = result.getFieldAs(0);
            int idBValue = result.getFieldAs(1);
            double distValue = result.getFieldAs(2);
            System.out.printf(
                    "ID from left: %d \tID from right: %d \t Distance: %f\n",
                    idAValue, idBValue, distValue);
        }
    }
}

4.15 MinMaxScaler

MinMaxScaler is an algorithm that rescales feature values to a common range [min, max] defined by the user.
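
Per feature, this is the standard min-max rescaling, with $E_{\min}$ and $E_{\max}$ the feature's minimum and maximum over the training data: $x' = \frac{x - E_{\min}}{E_{\max} - E_{\min}} \cdot (\max - \min) + \min$.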

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | Vector | “input” | Features to be scaled.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | Vector | “output” | Scaled features.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.
min | 0.0 | Double | no | Lower bound of the output feature range.
max | 1.0 | Double | no | Upper bound of the output feature range.

代码示例

import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a MinMaxScaler model and uses it for feature engineering. */
public class MinMaxScalerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Vectors.dense(0.0, 3.0)),
                        Row.of(Vectors.dense(2.1, 0.0)),
                        Row.of(Vectors.dense(4.1, 5.1)),
                        Row.of(Vectors.dense(6.1, 8.1)),
                        Row.of(Vectors.dense(200, 400)));
        Table trainTable = tEnv.fromDataStream(trainStream).as("input");

        DataStream<Row> predictStream =
                env.fromElements(
                        Row.of(Vectors.dense(150.0, 90.0)),
                        Row.of(Vectors.dense(50.0, 40.0)),
                        Row.of(Vectors.dense(100.0, 50.0)));
        Table predictTable = tEnv.fromDataStream(predictStream).as("input");

        // Creates a MinMaxScaler object and initializes its parameters.
        MinMaxScaler minMaxScaler = new MinMaxScaler();

        // Trains the MinMaxScaler Model.
        MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainTable);

        // Uses the MinMaxScaler Model for predictions.
        Table outputTable = minMaxScalerModel.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector inputValue = (DenseVector) row.getField(minMaxScaler.getInputCol());
            DenseVector outputValue = (DenseVector) row.getField(minMaxScaler.getOutputCol());
            System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}

4.16 NGram

NGram converts the input string array into an array of n-grams, where each n-gram is represented by a space-separated string of words. For example, with n = 2 the input ["a", "b", "c"] yields ["a b", "b c"]. If the length of the input array is less than n, no n-grams are returned.

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | String[] | “input” | Input string array.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | String[] | “output” | N-grams.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
n | 2 | Integer | no | Number of elements per n-gram (>=1).
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.

代码示例

import org.apache.flink.ml.feature.ngram.NGram;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that creates an NGram instance and uses it for feature engineering. */
public class NGramExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		DataStream<Row> inputStream =
			env.fromElements(
				Row.of((Object) new String[0]),
				Row.of((Object) new String[] {"a", "b", "c"}),
				Row.of((Object) new String[] {"a", "b", "c", "d"}));
		Table inputTable = tEnv.fromDataStream(inputStream).as("input");

		// Creates an NGram object and initializes its parameters.
		NGram nGram = new NGram().setN(2).setInputCol("input").setOutputCol("output");

		// Uses the NGram object for feature transformations.
		Table outputTable = nGram.transform(inputTable)[0];

		// Extracts and displays the results.
		for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
			Row row = it.next();

			String[] inputValue = (String[]) row.getField(nGram.getInputCol());
			String[] outputValue = (String[]) row.getField(nGram.getOutputCol());

			System.out.printf(
				"Input Value: %s \tOutput Value: %s\n",
				Arrays.toString(inputValue), Arrays.toString(outputValue));
		}
	}
}

4.17 Normalizer

A Transformer that normalizes a vector to have unit norm using the given p-norm.
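
That is, each vector $x$ is divided by its p-norm: $x' = x / \lVert x \rVert_p$ with $\lVert x \rVert_p = \big(\sum_i |x_i|^p\big)^{1/p}$.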

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | Vector | “input” | Vectors to be normalized.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | Vector | “output” | Normalized vectors.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.
p | 2.0 | Double | no | The p norm value.

代码示例

import org.apache.flink.ml.feature.normalizer.Normalizer;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that creates a Normalizer instance and uses it for feature engineering. */
public class NormalizerExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		DataStream<Row> inputStream =
			env.fromElements(
				Row.of(Vectors.dense(2.1, 3.1, 1.2, 3.1, 4.6)),
				Row.of(Vectors.dense(1.2, 3.1, 4.6, 2.1, 3.1)));
		Table inputTable = tEnv.fromDataStream(inputStream).as("inputVec");

		// Creates a Normalizer object and initializes its parameters.
		Normalizer normalizer =
			new Normalizer().setInputCol("inputVec").setP(3.0).setOutputCol("outputVec");

		// Uses the Normalizer object for feature transformations.
		Table outputTable = normalizer.transform(inputTable)[0];

		// Extracts and displays the results.
		for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
			Row row = it.next();

			Vector inputValue = (Vector) row.getField(normalizer.getInputCol());

			Vector outputValue = (Vector) row.getField(normalizer.getOutputCol());

			System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
		}
	}
}

4.18 OneHotEncoder

OneHotEncoder maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. This encoding allows algorithms that expect continuous features, such as Logistic Regression, to use categorical features.

OneHotEncoder can transform multiple columns, returning a one-hot-encoded output vector column for each input column.
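
For example, with the three categories {0.0, 1.0, 2.0} seen in the training data of the code example below and the default dropLast = true, the encoded vectors have size 2: input 0.0 maps to (1, 0), input 1.0 maps to (0, 1), and the last category 2.0 maps to the all-zero vector.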

输入列

参数名称 | 类型 | 默认值 | 描述
inputCols | Integer | null | Label index.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCols | Vector | null | Encoded binary vector.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCols | null | String[] | yes | Input column names.
outputCols | null | String[] | yes | Output column names.
handleInvalid | “error” | String | no | Strategy to handle invalid entries. Supported values: ‘error’, ‘skip’, ‘keep’.
dropLast | true | Boolean | no | Whether to drop the last category.

代码示例

import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a OneHotEncoder model and uses it for feature engineering. */
public class OneHotEncoderExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0));
        Table trainTable = tEnv.fromDataStream(trainStream).as("input");

        DataStream<Row> predictStream = env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0));
        Table predictTable = tEnv.fromDataStream(predictStream).as("input");

        // Creates a OneHotEncoder object and initializes its parameters.
        OneHotEncoder oneHotEncoder =
                new OneHotEncoder().setInputCols("input").setOutputCols("output");

        // Trains the OneHotEncoder Model.
        OneHotEncoderModel model = oneHotEncoder.fit(trainTable);

        // Uses the OneHotEncoder Model for predictions.
        Table outputTable = model.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            Double inputValue = (Double) row.getField(oneHotEncoder.getInputCols()[0]);
            SparseVector outputValue =
                    (SparseVector) row.getField(oneHotEncoder.getOutputCols()[0]);
            System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}

4.19 OnlineStandardScaler

An Estimator that implements the online standard scaling algorithm, the online version of StandardScaler.

OnlineStandardScaler splits the input data by the user-specified window strategy. For each window, it computes the mean and standard deviation using the data seen so far (i.e., not only the data in the current window, but also the history data). The model data generated by OnlineStandardScaler is a model stream. There is one model data for each window.

During the inference phase (i.e., using OnlineStandardScalerModel for prediction), users can output the version of the model that was used for predicting each data point. Moreover:

  • When the train data and test data both contain event time, users can specify the maximum difference between the timestamps of the input and model data, which enforces using a relatively fresh model for prediction.
  • Otherwise, the prediction process always uses the current model data for prediction.

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | Vector | “input” | Features to be scaled.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | Vector | “output” | Scaled features.
modelVersionCol | String | “version” | The version of the model data that the input data is predicted with. The version is a 64-bit integer.

算法参数

Below are the parameters required by OnlineStandardScalerModel.

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.
withMean | false | Boolean | no | Whether to center the data with mean before scaling.
withStd | true | Boolean | no | Whether to scale the data with standard deviation.
modelVersionCol | “version” | String | no | The name of the column which contains the version of the model data that the input data is predicted with. The version should be a 64-bit integer.
maxAllowedModelDelayMs | 0L | Long | no | The maximum difference allowed between the timestamps of the input record and the model data that is used to predict that input record. This parameter only works when the input contains event time.

OnlineStandardScaler needs parameters above and also below.

参数 | 默认值 | 类型 | 是否必须 | 描述
windows | GlobalWindows.getInstance() | Windows | no | Windowing strategy that determines how to create mini-batches from input data.

代码示例

import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.List;

/** Simple program that trains an OnlineStandardScaler model and uses it for feature engineering. */
public class OnlineStandardScalerExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		List<Row> inputData =
			Arrays.asList(
				Row.of(0L, Vectors.dense(-2.5, 9, 1)),
				Row.of(1000L, Vectors.dense(1.4, -5, 1)),
				Row.of(2000L, Vectors.dense(2, -1, -2)),
				Row.of(6000L, Vectors.dense(0.7, 3, 1)),
				Row.of(7000L, Vectors.dense(0, 1, 1)),
				Row.of(8000L, Vectors.dense(0.5, 0, -2)),
				Row.of(9000L, Vectors.dense(0.4, 1, 1)),
				Row.of(10000L, Vectors.dense(0.3, 2, 1)),
				Row.of(11000L, Vectors.dense(0.5, 1, -2)));

		DataStream<Row> inputStream = env.fromCollection(inputData);

		DataStream<Row> inputStreamWithEventTime =
			inputStream.assignTimestampsAndWatermarks(
				WatermarkStrategy.<Row>forMonotonousTimestamps()
					.withTimestampAssigner(
						(SerializableTimestampAssigner<Row>)
							(element, recordTimestamp) ->
								element.getFieldAs(0)));

		Table inputTable =
			tEnv.fromDataStream(
					inputStreamWithEventTime,
					Schema.newBuilder()
						.column("f0", DataTypes.BIGINT())
						.column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
						.columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)")
						.watermark("rowtime", "SOURCE_WATERMARK()")
						.build())
				.as("id", "input");

		// Creates an OnlineStandardScaler object and initializes its parameters.
		long windowSizeMs = 3000;
		OnlineStandardScaler onlineStandardScaler =
			new OnlineStandardScaler()
				.setWindows(EventTimeTumblingWindows.of(Time.milliseconds(windowSizeMs)));

		// Trains the OnlineStandardScaler Model.
		OnlineStandardScalerModel model = onlineStandardScaler.fit(inputTable);

		// Uses the OnlineStandardScaler Model for predictions.
		Table outputTable = model.transform(inputTable)[0];

		// Extracts and displays the results.
		for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
			Row row = it.next();
			DenseVector inputValue = (DenseVector) row.getField(onlineStandardScaler.getInputCol());
			DenseVector outputValue =
				(DenseVector) row.getField(onlineStandardScaler.getOutputCol());
			long modelVersion = row.getFieldAs(onlineStandardScaler.getModelVersionCol());
			System.out.printf(
				"Input Value: %s\tOutput Value: %s\tModel Version: %s\n",
				inputValue, outputValue, modelVersion);
		}
	}
}

4.20 PolynomialExpansion

A Transformer that expands the input vectors in polynomial space.

Take a 2-dimensional vector (x, y) as an example: if we expand it with degree 2, we get (x, x * x, y, x * y, y * y).

For more information about the polynomial expansion, see http://en.wikipedia.org/wiki/Polynomial_expansion.
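
As a rule of thumb, expanding an n-dimensional vector with degree d produces $\binom{n+d}{d} - 1$ terms; for the 2-dimensional, degree-2 case above that is $\binom{4}{2} - 1 = 5$, matching the five terms listed.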

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | Vector | “input” | Vectors to be expanded.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | Vector | “output” | Expanded vectors.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.
degree | 2 | Integer | no | Degree of the polynomial expansion.

代码示例

import org.apache.flink.ml.feature.polynomialexpansion.PolynomialExpansion;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that creates a PolynomialExpansion instance and uses it for feature engineering. */
public class PolynomialExpansionExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		DataStream<Row> inputStream =
			env.fromElements(
				Row.of(Vectors.dense(2.1, 3.1, 1.2)),
				Row.of(Vectors.dense(1.2, 3.1, 4.6)));
		Table inputTable = tEnv.fromDataStream(inputStream).as("inputVec");

		// Creates a PolynomialExpansion object and initializes its parameters.
		PolynomialExpansion polynomialExpansion =
			new PolynomialExpansion().setInputCol("inputVec").setDegree(2).setOutputCol("outputVec");

		// Uses the PolynomialExpansion object for feature transformations.
		Table outputTable = polynomialExpansion.transform(inputTable)[0];

		// Extracts and displays the results.
		for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
			Row row = it.next();

			Vector inputValue = (Vector) row.getField(polynomialExpansion.getInputCol());

			Vector outputValue = (Vector) row.getField(polynomialExpansion.getOutputCol());

			System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
		}
	}
}

4.21 RandomSplitter

An AlgoOperator which splits a table into N tables according to the given weights.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
weights | [1.0, 1.0] | Double[] | no | The weights of data splitting.
seed | null | Long | no | The random seed. This parameter guarantees reproducible output only when the parallelism is unchanged and each worker reads the same data in the same order.

代码示例

import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that creates a RandomSplitter instance and uses it for data splitting. */
public class RandomSplitterExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		DataStream<Row> inputStream =
			env.fromElements(
				Row.of(1, 10, 0),
				Row.of(1, 10, 0),
				Row.of(1, 10, 0),
				Row.of(4, 10, 0),
				Row.of(5, 10, 0),
				Row.of(6, 10, 0),
				Row.of(7, 10, 0),
				Row.of(10, 10, 0),
				Row.of(13, 10, 3));
		Table inputTable = tEnv.fromDataStream(inputStream).as("input");

		// Creates a RandomSplitter object and initializes its parameters.
		RandomSplitter splitter = new RandomSplitter().setWeights(4.0, 6.0);

		// Uses the RandomSplitter to split inputData.
		Table[] outputTables = splitter.transform(inputTable);

		// Extracts and displays the results.
		System.out.println("Split Result 1 (40%)");
		for (CloseableIterator<Row> it = outputTables[0].execute().collect(); it.hasNext(); ) {
			System.out.printf("%s\n", it.next());
		}
		System.out.println("Split Result 2 (60%)");
		for (CloseableIterator<Row> it = outputTables[1].execute().collect(); it.hasNext(); ) {
			System.out.printf("%s\n", it.next());
		}
	}
}

4.22 RegexTokenizer

RegexTokenizer is an algorithm that converts the input string to lowercase and then splits it into tokens based on a regex pattern (by default, splitting on whitespace).

输入列

参数名称 | 类型 | 默认值 | 描述
inputCol | String | “input” | Strings to be tokenized.

输出列

参数名称 | 类型 | 默认值 | 描述
outputCol | String[] | “output” | Tokenized strings.

算法参数

参数 | 默认值 | 类型 | 是否必须 | 描述
minTokenLength | 1 | Integer | no | Minimum token length.
gaps | true | Boolean | no | Whether the regex pattern matches gaps (true) or tokens (false).
pattern | “\s+” | String | no | Regex pattern used for tokenizing.
toLowercase | true | Boolean | no | Whether to convert all characters to lowercase before tokenizing.
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.

代码示例

import org.apache.flink.ml.feature.regextokenizer.RegexTokenizer;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that creates a RegexTokenizer instance and uses it for feature engineering. */
public class RegexTokenizerExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		DataStream<Row> inputStream =
			env.fromElements(Row.of("Test for tokenization."), Row.of("Te,st. punct"));
		Table inputTable = tEnv.fromDataStream(inputStream).as("input");

		// Creates a RegexTokenizer object and initializes its parameters.
		RegexTokenizer regexTokenizer =
			new RegexTokenizer()
				.setInputCol("input")
				.setOutputCol("output")
				.setPattern("\\w+|\\p{Punct}");

		// Uses the RegexTokenizer object for feature transformations.
		Table outputTable = regexTokenizer.transform(inputTable)[0];

		// Extracts and displays the results.
		for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
			Row row = it.next();

			String inputValue = (String) row.getField(regexTokenizer.getInputCol());
			String[] outputValues = (String[]) row.getField(regexTokenizer.getOutputCol());

			System.out.printf(
				"Input Value: %s \tOutput Values: %s\n",
				inputValue, Arrays.toString(outputValues));
		}
	}
}

4.23 RobustScaler

RobustScaler is an algorithm that scales features using statistics that are robust to outliers.

This Scaler removes the median and scales the data according to the quantile range (defaults to IQR: Interquartile Range). The IQR is the range between the 1st quartile (25th quantile) and the 3rd quartile (75th quantile) but can be configured.

Centering and scaling happen independently on each feature by computing the relevant statistics on the samples in the training set. Median and quantile range are then stored to be used on later data using the transform method.

Standardization of a dataset is a common requirement for many machine learning estimators. Typically this is done by removing the mean and scaling to unit variance. However, outliers can often influence the sample mean / variance in a negative way. In such cases, the median and the interquartile range often give better results.

Note that NaN values are ignored in the computation of medians and ranges.
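
For intuition, take the first dimension of the training data in the example below: the values 0.0 through 8.0. Its median is 4.0, the 25th percentile is 2.0 and the 75th percentile is 6.0, so the quantile range is 6.0 − 2.0 = 4.0. With withCentering and withScaling both enabled, a value x is mapped to (x − 4.0) / 4.0 (up to the approximation error controlled by relativeError), so 8.0 becomes 1.0 and 0.0 becomes −1.0.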

Input Columns

Name | Type | Default | Description
inputCol | Vector | “input” | Features to be scaled.

Output Columns

Name | Type | Default | Description
outputCol | Vector | “output” | Scaled features.

Parameters

Below are the parameters required by RobustScalerModel.

Name | Default | Type | Required | Description
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.
withCentering | false | Boolean | no | Whether to center the data with the median before scaling.
withScaling | true | Boolean | no | Whether to scale the data to the quantile range.

RobustScaler needs the parameters above and also those below.

Name | Default | Type | Required | Description
lower | 0.25 | Double | no | Lower quantile to calculate quantile range.
upper | 0.75 | Double | no | Upper quantile to calculate quantile range.
relativeError | 0.001 | Double | no | The relative target precision for the approximate quantile algorithm.

Code Example

import org.apache.flink.ml.feature.robustscaler.RobustScaler;
import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a {@link RobustScaler} model and uses it for feature engineering. */
public class RobustScalerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(1, Vectors.dense(0.0, 0.0)),
                        Row.of(2, Vectors.dense(1.0, -1.0)),
                        Row.of(3, Vectors.dense(2.0, -2.0)),
                        Row.of(4, Vectors.dense(3.0, -3.0)),
                        Row.of(5, Vectors.dense(4.0, -4.0)),
                        Row.of(6, Vectors.dense(5.0, -5.0)),
                        Row.of(7, Vectors.dense(6.0, -6.0)),
                        Row.of(8, Vectors.dense(7.0, -7.0)),
                        Row.of(9, Vectors.dense(8.0, -8.0)));
        Table trainTable = tEnv.fromDataStream(trainStream).as("id", "input");

        // Creates a RobustScaler object and initializes its parameters.
        RobustScaler robustScaler =
                new RobustScaler()
                        .setLower(0.25)
                        .setUpper(0.75)
                        .setRelativeError(0.001)
                        .setWithScaling(true)
                        .setWithCentering(true);

        // Trains the RobustScaler model.
        RobustScalerModel model = robustScaler.fit(trainTable);

        // Uses the RobustScaler model for predictions.
        Table outputTable = model.transform(trainTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector inputValue = (DenseVector) row.getField(robustScaler.getInputCol());
            DenseVector outputValue = (DenseVector) row.getField(robustScaler.getOutputCol());
            System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}

4.24 SQLTransformer

SQLTransformer implements the transformations that are defined by a SQL statement.

Currently we only support SQL syntax like SELECT … FROM __THIS__ …, where __THIS__ represents the input table and cannot be modified.

The select clause specifies the fields, constants, and expressions to display in the output. Except for the cases described in the note below, it can be any select clause that Flink SQL supports. Users can also use Flink SQL built-in functions and UDFs to operate on these selected columns.

For example, SQLTransformer supports statements like:

SELECT a, a + b AS a_b FROM __THIS__
SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ WHERE a > 5
SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b

Note: This operator only generates an append-only/insert-only table as its output. If the output table could possibly contain retract messages (e.g., performing a SELECT … FROM __THIS__ GROUP BY … operation on a table in streaming mode), this operator aggregates all changelogs and only outputs the final state.

Parameters

Name | Default | Type | Required | Description
statement | null | String | yes | SQL statement.

Code Example

import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.feature.sqltransformer.SQLTransformer;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

import java.util.Arrays;

/** Simple program that creates a SQLTransformer instance and uses it for feature engineering. */
public class SQLTransformerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromCollection(
                        Arrays.asList(Row.of(0, 1.0, 3.0), Row.of(2, 2.0, 5.0)),
                        new RowTypeInfo(Types.INT, Types.DOUBLE, Types.DOUBLE));
        Table inputTable = tEnv.fromDataStream(inputStream).as("id", "v1", "v2");

        // Creates a SQLTransformer object and initializes its parameters.
        SQLTransformer sqlTransformer =
                new SQLTransformer()
                        .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__");

        // Uses the SQLTransformer object for feature transformations.
        Table outputTable = sqlTransformer.transform(inputTable)[0];

        // Extracts and displays the results.
        outputTable.execute().print();
    }
}
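
For the input above, the statement adds v3 = v1 + v2 and v4 = v1 * v2 to every row, so (0, 1.0, 3.0) becomes (0, 1.0, 3.0, 4.0, 3.0) and (2, 2.0, 5.0) becomes (2, 2.0, 5.0, 7.0, 10.0).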

4.25 StandardScaler

StandardScaler is an algorithm that standardizes the input features by removing the mean and scaling each dimension to unit variance.

Input Columns

Name | Type | Default | Description
inputCol | Vector | “input” | Features to be scaled.

Output Columns

Name | Type | Default | Description
outputCol | Vector | “output” | Scaled features.

Parameters

Name | Default | Type | Required | Description
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.
withMean | false | Boolean | no | Whether to center the data with the mean before scaling.
withStd | true | Boolean | no | Whether to scale the data with the standard deviation.
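
Each dimension j is transformed independently. With both flags enabled the mapping is x'_j = (x_j − μ_j) / σ_j, where μ_j and σ_j are the mean and standard deviation of dimension j over the training data; with the defaults (withMean = false, withStd = true), only the division by σ_j is applied.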

Code Example

import org.apache.flink.ml.feature.standardscaler.StandardScaler;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a StandardScaler model and uses it for feature engineering. */
public class StandardScalerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(-2.5, 9, 1)),
                        Row.of(Vectors.dense(1.4, -5, 1)),
                        Row.of(Vectors.dense(2, -1, -2)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("input");

        // Creates a StandardScaler object and initializes its parameters.
        StandardScaler standardScaler = new StandardScaler();

        // Trains the StandardScaler Model.
        StandardScalerModel model = standardScaler.fit(inputTable);

        // Uses the StandardScaler Model for predictions.
        Table outputTable = model.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector inputValue = (DenseVector) row.getField(standardScaler.getInputCol());
            DenseVector outputValue = (DenseVector) row.getField(standardScaler.getOutputCol());
            System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}

4.26 StopWordsRemover

A feature transformer that filters out stop words from the input.

Note: null values in the input array are preserved unless null is explicitly added to stopWords.

See Also: Stop words (Wikipedia)

Input Columns

Name | Type | Default | Description
inputCols | String[] | null | Arrays of strings to filter stop words from.

Output Columns

Name | Type | Default | Description
outputCols | String[] | null | Arrays of strings with stop words removed.

Parameters

Name | Default | Type | Required | Description
inputCols | null | String[] | yes | Input column names.
outputCols | null | String[] | yes | Output column names.
stopWords | StopWordsRemover.loadDefaultStopWords(“english”) | String[] | no | The words to be filtered out.
caseSensitive | false | Boolean | no | Whether to do a case-sensitive comparison over the stop words.
locale | StopWordsRemover.getDefaultOrUS().toString() | String | no | Locale of the input for case-insensitive matching. Ignored when caseSensitive is true.
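
A minimal sketch of supplying a custom stop-word list, which could replace the construction step in the code example below (assuming a setStopWords setter corresponding to the stopWords parameter):

// Extends the default English stop words with one extra, hypothetical word.
String[] defaults = StopWordsRemover.loadDefaultStopWords("english");
String[] custom = java.util.Arrays.copyOf(defaults, defaults.length + 1);
custom[defaults.length] = "flink"; // hypothetical extra stop word

StopWordsRemover remover =
        new StopWordsRemover()
                .setInputCols("input")
                .setOutputCols("output")
                .setStopWords(custom); // assumed setter for stopWords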

Code Example

import org.apache.flink.ml.feature.stopwordsremover.StopWordsRemover;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that creates a StopWordsRemover instance and uses it for feature engineering. */
public class StopWordsRemoverExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of((Object) new String[] {"test", "test"}),
                        Row.of((Object) new String[] {"a", "b", "c", "d"}),
                        Row.of((Object) new String[] {"a", "the", "an"}),
                        Row.of((Object) new String[] {"A", "The", "AN"}),
                        Row.of((Object) new String[] {null}),
                        Row.of((Object) new String[] {}));
        Table inputTable = tEnv.fromDataStream(inputStream).as("input");

        // Creates a StopWordsRemover object and initializes its parameters.
        StopWordsRemover remover =
                new StopWordsRemover().setInputCols("input").setOutputCols("output");

        // Uses the StopWordsRemover object for feature transformations.
        Table outputTable = remover.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            String[] inputValues = row.getFieldAs("input");
            String[] outputValues = row.getFieldAs("output");

            System.out.printf(
                    "Input Values: %s\tOutput Values: %s\n",
                    Arrays.toString(inputValues), Arrays.toString(outputValues));
        }
    }
}

4.27 StringIndexer

StringIndexer maps one or more columns (string/numerical value) of the input to one or more indexed output columns (integer value). The output indices of two data points are the same iff their corresponding input columns are the same. The indices are in [0, numDistinctValuesInThisColumn].

IndexToStringModel transforms the input index column(s) back to string column(s) using the model data computed by StringIndexer. It is the reverse operation of StringIndexerModel.

Input Columns

Name | Type | Default | Description
inputCols | Number/String | null | String/numerical values to be indexed.

Output Columns

Name | Type | Default | Description
outputCols | Double | null | Indices of the string/numerical values.

Parameters

Below are the parameters required by StringIndexerModel.

Name | Default | Type | Required | Description
inputCols | null | String[] | yes | Input column names.
outputCols | null | String[] | yes | Output column names.
handleInvalid | “error” | String | no | Strategy to handle invalid entries. Supported values: ‘error’, ‘skip’, ‘keep’.

StringIndexer needs the parameters above and also those below.

Name | Default | Type | Required | Description
stringOrderType | “arbitrary” | String | no | How to order the strings of each column. Supported values: ‘arbitrary’, ‘frequencyDesc’, ‘frequencyAsc’, ‘alphabetDesc’, ‘alphabetAsc’.
maxIndexNum | 2147483647 | Integer | no | The max number of indices for each column. Only effective when ‘stringOrderType’ is set to ‘frequencyDesc’.
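
A sketch of ordering values by frequency and capping the number of indices per column (setStringOrderType appears in the code example below; assume setMaxIndexNum and setHandleInvalid setters corresponding to the parameters above):

// Hypothetical configuration: the most frequent value gets index 0, at most 3
// indices are created per column, and unseen values are kept under a dedicated
// index instead of raising an error.
StringIndexer indexer =
        new StringIndexer()
                .setInputCols("inputCol1")
                .setOutputCols("outputCol1")
                .setStringOrderType("frequencyDesc")
                .setMaxIndexNum(3)
                .setHandleInvalid("keep");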

Code Example

import org.apache.flink.ml.feature.stringindexer.StringIndexer;
import org.apache.flink.ml.feature.stringindexer.StringIndexerModel;
import org.apache.flink.ml.feature.stringindexer.StringIndexerParams;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that trains a StringIndexer model and uses it for feature engineering. */
public class StringIndexerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of("a", 1.0),
                        Row.of("b", 1.0),
                        Row.of("b", 2.0),
                        Row.of("c", 0.0),
                        Row.of("d", 2.0),
                        Row.of("a", 2.0),
                        Row.of("b", 2.0),
                        Row.of("b", -1.0),
                        Row.of("a", -1.0),
                        Row.of("c", -1.0));
        Table trainTable = tEnv.fromDataStream(trainStream).as("inputCol1", "inputCol2");

        DataStream<Row> predictStream =
                env.fromElements(Row.of("a", 2.0), Row.of("b", 1.0), Row.of("c", 2.0));
        Table predictTable = tEnv.fromDataStream(predictStream).as("inputCol1", "inputCol2");

        // Creates a StringIndexer object and initializes its parameters.
        StringIndexer stringIndexer =
                new StringIndexer()
                        .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
                        .setInputCols("inputCol1", "inputCol2")
                        .setOutputCols("outputCol1", "outputCol2");

        // Trains the StringIndexer Model.
        StringIndexerModel model = stringIndexer.fit(trainTable);

        // Uses the StringIndexer Model for predictions.
        Table outputTable = model.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            Object[] inputValues = new Object[stringIndexer.getInputCols().length];
            double[] outputValues = new double[stringIndexer.getInputCols().length];
            for (int i = 0; i < inputValues.length; i++) {
                inputValues[i] = row.getField(stringIndexer.getInputCols()[i]);
                outputValues[i] = (double) row.getField(stringIndexer.getOutputCols()[i]);
            }

            System.out.printf(
                    "Input Values: %s \tOutput Values: %s\n",
                    Arrays.toString(inputValues), Arrays.toString(outputValues));
        }
    }
}

4.28 Tokenizer

Tokenizer is an algorithm that converts the input string to lowercase and then splits it by white spaces.

Input Columns

Name | Type | Default | Description
inputCol | String | “input” | Strings to be tokenized.

Output Columns

Name | Type | Default | Description
outputCol | String[] | “output” | Tokenized strings.

Parameters

Name | Default | Type | Required | Description
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.

Code Example

import org.apache.flink.ml.feature.tokenizer.Tokenizer;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that creates a Tokenizer instance and uses it for feature engineering. */
public class TokenizerExample {
	public static void main(String[] args) {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

		// Generates input data.
		DataStream<Row> inputStream =
			env.fromElements(Row.of("Test for tokenization."), Row.of("Te,st. punct"));
		Table inputTable = tEnv.fromDataStream(inputStream).as("input");

		// Creates a Tokenizer object and initializes its parameters.
		Tokenizer tokenizer = new Tokenizer().setInputCol("input").setOutputCol("output");

		// Uses the Tokenizer object for feature transformations.
		Table outputTable = tokenizer.transform(inputTable)[0];

		// Extracts and displays the results.
		for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
			Row row = it.next();

			String inputValue = (String) row.getField(tokenizer.getInputCol());
			String[] outputValues = (String[]) row.getField(tokenizer.getOutputCol());

			System.out.printf(
				"Input Value: %s \tOutput Values: %s\n",
				inputValue, Arrays.toString(outputValues));
		}
	}
}

4.29 UnivariateFeatureSelector

UnivariateFeatureSelector is an algorithm that selects features based on univariate statistical tests against labels.

Currently, Flink supports three univariate statistical tests: chi-squared, ANOVA F-test, and F-value. Users choose among them by setting featureType and labelType, and Flink picks the score function based on the specified combination.

The following combinations of featureType and labelType are supported:

  • featureType categorical and labelType categorical: Flink uses chi-squared, i.e. chi2 in sklearn.
  • featureType continuous and labelType categorical: Flink uses ANOVA F-test, i.e. f_classif in sklearn.
  • featureType continuous and labelType continuous: Flink uses F-value, i.e. f_regression in sklearn.

UnivariateFeatureSelector supports different selection modes:

  • numTopFeatures: chooses a fixed number of top features according to a hypothesis test.
  • percentile: similar to numTopFeatures but chooses a fraction of all features instead of a fixed number.
  • fpr: chooses all features whose p-values are below a threshold, thus controlling the false positive rate of selection.
  • fdr: uses the Benjamini-Hochberg procedure to choose all features whose false discovery rate is below a threshold.
  • fwe: chooses all features whose p-values are below a threshold. The threshold is scaled by 1/numFeatures, thus controlling the family-wise error rate of selection.

By default, the selection mode is numTopFeatures.
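
As a sketch, selecting the top 25% of features with the ANOVA F-test could replace the construction step in the code example below (setSelectionThreshold appears in that example; assume a setSelectionMode setter corresponding to the selectionMode parameter):

// Hypothetical configuration: percentile mode keeps the top fraction of features.
UnivariateFeatureSelector selector =
        new UnivariateFeatureSelector()
                .setFeaturesCol("features")
                .setLabelCol("label")
                .setFeatureType("continuous")   // continuous features ...
                .setLabelType("categorical")    // ... and categorical labels => ANOVA F-test
                .setSelectionMode("percentile") // assumed setter for selectionMode
                .setSelectionThreshold(0.25);   // keep the top 25% of features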

Input Columns

Name | Type | Default | Description
featuresCol | Vector | “features” | Feature vector.
labelCol | Number | “label” | Label of the features.

Output Columns

Name | Type | Default | Description
outputCol | Vector | “output” | Selected features.

Parameters

Below are the parameters required by UnivariateFeatureSelectorModel.

Name | Default | Type | Required | Description
featuresCol | “features” | String | no | Features column name.
outputCol | “output” | String | no | Output column name.

UnivariateFeatureSelector needs the parameters above and also those below.

Name | Default | Type | Required | Description
labelCol | “label” | String | no | Label column name.
featureType | null | String | yes | The feature type. Supported values: ‘categorical’, ‘continuous’.
labelType | null | String | yes | The label type. Supported values: ‘categorical’, ‘continuous’.
selectionMode | “numTopFeatures” | String | no | The feature selection mode. Supported values: ‘numTopFeatures’, ‘percentile’, ‘fpr’, ‘fdr’, ‘fwe’.
selectionThreshold | null | Number | no | The upper bound of the features that the selector will select. If not set, it will be replaced with a meaningful value according to the selection mode at runtime: 50 for numTopFeatures, 0.1 for percentile, and 0.05 otherwise.

Code Example

import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/**
 * Simple program that trains a {@link UnivariateFeatureSelector} model and uses it for feature
 * selection.
 */
public class UnivariateFeatureSelectorExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0),
                        Row.of(Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0),
                        Row.of(Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 1.0),
                        Row.of(Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0),
                        Row.of(Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0),
                        Row.of(Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0));
        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");

        // Creates a UnivariateFeatureSelector object and initializes its parameters.
        UnivariateFeatureSelector univariateFeatureSelector =
                new UnivariateFeatureSelector()
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setFeatureType("continuous")
                        .setLabelType("categorical")
                        .setSelectionThreshold(1);

        // Trains the UnivariateFeatureSelector model.
        UnivariateFeatureSelectorModel model = univariateFeatureSelector.fit(trainTable);

        // Uses the UnivariateFeatureSelector model for predictions.
        Table outputTable = model.transform(trainTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector inputValue =
                    (DenseVector) row.getField(univariateFeatureSelector.getFeaturesCol());
            DenseVector outputValue =
                    (DenseVector) row.getField(univariateFeatureSelector.getOutputCol());
            System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}

4.30 VarianceThresholdSelector

VarianceThresholdSelector is a selector that removes low-variance features. Features with a variance not greater than the varianceThreshold will be removed. If not set, varianceThreshold defaults to 0, which means only features with variance 0 (i.e. features that have the same value in all samples) will be removed.

Input Columns

Name | Type | Default | Description
inputCol | Vector | “input” | Input features.

Output Columns

Name | Type | Default | Description
outputCol | Vector | “output” | Selected features.

Parameters

Below are the parameters required by VarianceThresholdSelectorModel.

Name | Default | Type | Required | Description
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.

VarianceThresholdSelector needs the parameters above and also those below.

Name | Default | Type | Required | Description
varianceThreshold | 0.0 | Double | no | Features with a variance not greater than this threshold will be removed.

Code Example

import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector;
import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/**
 * Simple program that trains a {@link VarianceThresholdSelector} model and uses it for feature
 * selection.
 */
public class VarianceThresholdSelectorExample {

    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                env.fromElements(
                        Row.of(1, Vectors.dense(5.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
                        Row.of(2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
                        Row.of(3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
                        Row.of(4, Vectors.dense(1.0, 9.0, 8.0, 5.0, 7.0, 4.0)),
                        Row.of(5, Vectors.dense(9.0, 8.0, 6.0, 5.0, 4.0, 4.0)),
                        Row.of(6, Vectors.dense(6.0, 9.0, 7.0, 0.0, 2.0, 0.0)));
        Table trainTable = tEnv.fromDataStream(trainStream).as("id", "input");

        // Creates a VarianceThresholdSelector object and initializes its parameters.
        double threshold = 8.0;
        VarianceThresholdSelector varianceThresholdSelector =
                new VarianceThresholdSelector()
                        .setVarianceThreshold(threshold)
                        .setInputCol("input");

        // Trains the VarianceThresholdSelector model.
        VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(trainTable);

        // Uses the VarianceThresholdSelector model for predictions.
        Table outputTable = model.transform(trainTable)[0];

        // Extracts and displays the results.
        System.out.printf("Variance Threshold: %s\n", threshold);
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector inputValue =
                    (DenseVector) row.getField(varianceThresholdSelector.getInputCol());
            DenseVector outputValue =
                    (DenseVector) row.getField(varianceThresholdSelector.getOutputCol());
            System.out.printf("Input Values: %-15s\tOutput Values: %s\n", inputValue, outputValue);
        }
    }
}

4.31 VectorAssembler

A Transformer which combines a given list of input columns into a vector column. Input columns can be numerical columns or vectors whose sizes are specified by the inputSizes parameter. Invalid input data with null values or values of the wrong size is handled according to the strategy specified by the handleInvalid parameter, as follows:

  • keep: If the input column data is null, a vector of the specified size filled with NaN values is created and used in the assembling process. If the input column data is a vector, the data is used in the assembling process even if it has a wrong size.
  • skip: If the input column data is null or a vector of the wrong size, the input row is filtered out and not sent to downstream operators.
  • error: If the input column data is null or a vector of the wrong size, an exception is thrown.
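
For example, a minimal sketch of tolerating malformed rows instead of failing (assuming a setHandleInvalid setter corresponding to the handleInvalid parameter below):

// Hypothetical configuration: silently drop rows with null or wrongly sized inputs.
VectorAssembler assembler =
        new VectorAssembler()
                .setInputCols("vec", "num")
                .setOutputCol("assembledVec")
                .setInputSizes(2, 1)
                .setHandleInvalid("skip");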

Input Columns

Name | Type | Default | Description
inputCols | Number/Vector | null | Numbers/vectors to be assembled.

Output Columns

Name | Type | Default | Description
outputCol | Vector | “output” | Assembled vector.

Parameters

Name | Default | Type | Required | Description
inputCols | null | String[] | yes | Input column names.
outputCol | “output” | String | no | Output column name.
inputSizes | null | Integer[] | yes | Sizes of the input elements to be assembled.
handleInvalid | “error” | String | no | Strategy to handle invalid entries. Supported values: ‘error’, ‘skip’, ‘keep’.

Code Example

import org.apache.flink.ml.feature.vectorassembler.VectorAssembler;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;

/** Simple program that creates a VectorAssembler instance and uses it for feature engineering. */
public class VectorAssemblerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(
                                Vectors.dense(2.1, 3.1),
                                1.0,
                                Vectors.sparse(5, new int[] {3}, new double[] {1.0})),
                        Row.of(
                                Vectors.dense(2.1, 3.1),
                                1.0,
                                Vectors.sparse(
                                        5,
                                        new int[] {4, 2, 3, 1},
                                        new double[] {4.0, 2.0, 3.0, 1.0})));
        Table inputTable = tEnv.fromDataStream(inputStream).as("vec", "num", "sparseVec");

        // Creates a VectorAssembler object and initializes its parameters.
        VectorAssembler vectorAssembler =
                new VectorAssembler()
                        .setInputCols("vec", "num", "sparseVec")
                        .setOutputCol("assembledVec")
                        .setInputSizes(2, 1, 5);

        // Uses the VectorAssembler object for feature transformations.
        Table outputTable = vectorAssembler.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            Object[] inputValues = new Object[vectorAssembler.getInputCols().length];
            for (int i = 0; i < inputValues.length; i++) {
                inputValues[i] = row.getField(vectorAssembler.getInputCols()[i]);
            }

            Vector outputValue = (Vector) row.getField(vectorAssembler.getOutputCol());

            System.out.printf(
                    "Input Values: %s \tOutput Value: %s\n",
                    Arrays.toString(inputValues), outputValue);
        }
    }
}

4.32 VectorIndexer

VectorIndexer is an algorithm that indexes the columns of input vectors, mapping each column to a continuous or categorical feature. Whether a column is treated as continuous or categorical depends on its number of distinct values: if a column has more distinct values than a specified parameter (maxCategories), the corresponding output column is left unchanged; otherwise it is transformed into categorical indices. For categorical outputs, the indices are in [0, numDistinctValuesInThisColumn].

The output model is organized in ascending order except that 0.0 is always mapped to 0 (for sparsity).

Input Columns

Name | Type | Default | Description
inputCol | Vector | “input” | Vectors to be indexed.

Output Columns

Name | Type | Default | Description
outputCol | Vector | “output” | Indexed vectors.

Parameters

Below are the parameters required by VectorIndexerModel.

Name | Default | Type | Required | Description
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.
handleInvalid | “error” | String | no | Strategy to handle invalid entries. Supported values: ‘error’, ‘skip’, ‘keep’.

VectorIndexer needs the parameters above and also those below.

Name | Default | Type | Required | Description
maxCategories | 20 | Integer | no | Threshold for the number of values a categorical feature can take (>= 2). If a feature is found to have more than maxCategories values, it is declared continuous.

Code Example

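A minimal sketch of the VectorIndexer workflow, following the conventions of the other examples in this post; the package name org.apache.flink.ml.feature.vectorindexer and the setter names are inferred from the parameter table above, so treat them as assumptions:

import org.apache.flink.ml.feature.vectorindexer.VectorIndexer;
import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Sketch that trains a VectorIndexer model and uses it for feature engineering. */
public class VectorIndexerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data. The first dimension takes five distinct values and stays
        // continuous under maxCategories = 3; the second takes three distinct values and
        // is therefore indexed as a categorical feature.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(1.0, 1.0)),
                        Row.of(Vectors.dense(2.0, -1.0)),
                        Row.of(Vectors.dense(3.0, 1.0)),
                        Row.of(Vectors.dense(4.0, 0.0)),
                        Row.of(Vectors.dense(5.0, 0.0)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("input");

        // Creates a VectorIndexer object and initializes its parameters.
        VectorIndexer vectorIndexer =
                new VectorIndexer()
                        .setInputCol("input")
                        .setOutputCol("output")
                        .setMaxCategories(3);

        // Trains the VectorIndexer Model.
        VectorIndexerModel model = vectorIndexer.fit(inputTable);

        // Uses the VectorIndexer Model for transformations.
        Table outputTable = model.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector inputValue = (DenseVector) row.getField(vectorIndexer.getInputCol());
            DenseVector outputValue = (DenseVector) row.getField(vectorIndexer.getOutputCol());
            System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}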

4.33 VectorSlicer

VectorSlicer transforms a vector into a new, shorter vector whose values are a sub-array of the original features. It is useful for extracting features from a given vector.

Note that duplicate features are not allowed, so there can be no overlap between the selected indices. If the max value of the indices is greater than the size of the input vector, an IllegalArgumentException is thrown.

Input Columns

Name | Type | Default | Description
inputCol | Vector | “input” | Vector to be sliced.

Output Columns

Name | Type | Default | Description
outputCol | Vector | “output” | Sliced vector.

Parameters

Name | Default | Type | Required | Description
inputCol | “input” | String | no | Input column name.
outputCol | “output” | String | no | Output column name.
indices | null | Integer[] | yes | An array of indices to select features from the input vector.

Code Example

import org.apache.flink.ml.feature.vectorslicer.VectorSlicer;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that creates a VectorSlicer instance and uses it for feature engineering. */
public class VectorSlicerExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(2.1, 3.1, 1.2, 3.1, 4.6)),
                        Row.of(Vectors.dense(1.2, 3.1, 4.6, 2.1, 3.1)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("vec");

        // Creates a VectorSlicer object and initializes its parameters.
        VectorSlicer vectorSlicer =
                new VectorSlicer().setInputCol("vec").setIndices(1, 2, 3).setOutputCol("slicedVec");

        // Uses the VectorSlicer object for feature transformations.
        Table outputTable = vectorSlicer.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            Vector inputValue = (Vector) row.getField(vectorSlicer.getInputCol());

            Vector outputValue = (Vector) row.getField(vectorSlicer.getOutputCol());

            System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
        }
    }
}
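
With indices (1, 2, 3), the vector (2.1, 3.1, 1.2, 3.1, 4.6) is sliced to (3.1, 1.2, 3.1) and (1.2, 3.1, 4.6, 2.1, 3.1) to (3.1, 4.6, 2.1).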

5 Recommendation

5.1 Swing

An AlgoOperator which implements the Swing algorithm.

Swing is an item-recall algorithm. The topology of the user-item graph can usually be described as user-item-user or item-user-item, which looks like a ‘swing’. For example, if both user u and user v have purchased the same item i, the three form a relationship diagram similar to a swing. If u and v have also purchased item j in addition to i, then items i and j are considered similar.

See “Large Scale Product Graph Construction for Recommendation in E-commerce” by Xiaoyong Yang, Yadong Zhu and Yi Zhang.
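
The core similarity from the paper, for items i and j with purchaser sets U_i and U_j, can be written as

    sim(i, j) = Σ_{u ∈ U_i ∩ U_j} Σ_{v ∈ U_i ∩ U_j} w_u · w_v · 1 / (α + |I_u ∩ I_v|)

where I_u is the set of items purchased by user u and w_u = 1 / √|I_u| is a user weight that discounts heavy purchasers. The Flink implementation additionally applies the smoothing factors alpha1, alpha2 and beta listed below; treat the exact way they enter the score as implementation-defined.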

Input Columns

Name | Type | Default | Description
itemCol | Long | “item” | Item id.
userCol | Long | “user” | User id.

Output Columns

Name | Type | Default | Description
itemCol | Long | “item” | Item id.
outputCol | String | “output” | Top k similar items and their corresponding scores (e.g., “item_1,0.9;item_2,0.7;item_3,0.35”).

Parameters

Below are the parameters required by Swing.

Name | Default | Type | Required | Description
userCol | “user” | String | no | User column name.
itemCol | “item” | String | no | Item column name.
maxUserNumPerItem | 1000 | Integer | no | The max number of users (purchasers) per item. If an item has more users than this value, only maxUserNumPerItem users are sampled and considered when computing the similarity between two items.
k | 100 | Integer | no | The max number of similar items to output for each item.
minUserBehavior | 10 | Integer | no | The min number of items a user must have purchased. Users who purchased fewer items are filtered out while gathering data. This can affect the speed of the computation; set minUserBehavior larger if the swing recommendation progresses very slowly.
maxUserBehavior | 1000 | Integer | no | The max number of items a user may have purchased. Users who purchased more items are filtered out while gathering data. This can affect the speed of the computation; set maxUserBehavior smaller if the swing recommendation progresses very slowly. An IllegalArgumentException is raised if maxUserBehavior is smaller than minUserBehavior.
alpha1 | 15 | Integer | no | Smooth factor for the number of users that have purchased one item. The higher alpha1 is, the less the purchasing behavior contributes to the similarity score.
alpha2 | 0 | Integer | no | Smooth factor for the number of users that have purchased the two target items. The higher alpha2 is, the less the purchasing behavior contributes to the similarity score.
beta | 0.3 | Double | no | Decay factor for the number of users that have purchased one item. The higher beta is, the less the purchasing behavior contributes to the similarity score.
outputCol | “output” | String | no | Output column name.

Code Example


import org.apache.flink.ml.recommendation.swing.Swing;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/**
 * Simple program that creates a Swing instance and uses it to generate recommendations for items.
 */
public class SwingExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(0L, 10L),
                        Row.of(0L, 11L),
                        Row.of(0L, 12L),
                        Row.of(1L, 13L),
                        Row.of(1L, 12L),
                        Row.of(2L, 10L),
                        Row.of(2L, 11L),
                        Row.of(2L, 12L),
                        Row.of(3L, 13L),
                        Row.of(3L, 12L));

        Table inputTable = tEnv.fromDataStream(inputStream).as("user", "item");

        // Creates a Swing object and initializes its parameters.
        Swing swing = new Swing().setUserCol("user").setItemCol("item").setMinUserBehavior(1);

        // Transforms the data.
        Table[] outputTable = swing.transform(inputTable);

        // Extracts and displays the result of swing algorithm.
        for (CloseableIterator<Row> it = outputTable[0].execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            long mainItem = row.getFieldAs(0);
            String itemRankScore = row.getFieldAs(1);

            System.out.printf("item: %d, top-k similar items: %s\n", mainItem, itemRankScore);
        }
    }
}

6 Regression

6.1 Linear Regression

Linear Regression is a kind of regression analysis that models the relationship between a scalar response and one or more explanatory variables.

Input Columns

Name | Type | Default | Description
featuresCol | Vector | “features” | Feature vector.
labelCol | Integer | “label” | Label to predict.
weightCol | Double | “weight” | Weight of the sample.

Output Columns

Name | Type | Default | Description
predictionCol | Integer | “prediction” | Predicted value.

Parameters

Below are the parameters required by LinearRegressionModel.

Name | Default | Type | Required | Description
featuresCol | “features” | String | no | Features column name.
predictionCol | “prediction” | String | no | Prediction column name.

LinearRegression needs the parameters above and also those below.

Name | Default | Type | Required | Description
labelCol | “label” | String | no | Label column name.
weightCol | null | String | no | Weight column name.
maxIter | 20 | Integer | no | Maximum number of iterations.
reg | 0. | Double | no | Regularization parameter.
elasticNet | 0. | Double | no | ElasticNet parameter.
learningRate | 0.1 | Double | no | Learning rate of the optimization method.
globalBatchSize | 32 | Integer | no | Global batch size of the training algorithm.
tol | 1e-6 | Double | no | Convergence tolerance for iterative algorithms.
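
Reading reg and elasticNet in the usual elastic-net form (an assumption based on the parameter names; the source does not spell out the objective), the training objective would be

    min_w (1/n) Σ_i (w · x_i − y_i)² + reg · ( elasticNet · ‖w‖₁ + (1 − elasticNet) / 2 · ‖w‖₂² )

so reg scales the overall penalty, while elasticNet interpolates between pure L2 regularization (0, the default) and pure L1 (1).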

Code Example

import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.regression.linearregression.LinearRegression;
import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a LinearRegression model and uses it for regression. */
public class LinearRegressionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(2, 1), 4.0, 1.0),
                        Row.of(Vectors.dense(3, 2), 7.0, 1.0),
                        Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                        Row.of(Vectors.dense(2, 4), 10.0, 1.0),
                        Row.of(Vectors.dense(2, 2), 6.0, 1.0),
                        Row.of(Vectors.dense(4, 3), 10.0, 1.0),
                        Row.of(Vectors.dense(1, 2), 5.0, 1.0),
                        Row.of(Vectors.dense(5, 3), 11.0, 1.0));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");

        // Creates a LinearRegression object and initializes its parameters.
        LinearRegression lr = new LinearRegression().setWeightCol("weight");

        // Trains the LinearRegression Model.
        LinearRegressionModel lrModel = lr.fit(inputTable);

        // Uses the LinearRegression Model for predictions.
        Table outputTable = lrModel.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(lr.getFeaturesCol());
            double expectedResult = (Double) row.getField(lr.getLabelCol());
            double predictionResult = (Double) row.getField(lr.getPredictionCol());
            System.out.printf(
                    "Features: %s \tExpected Result: %s \tPrediction Result: %s\n",
                    features, expectedResult, predictionResult);
        }
    }
}

7 Stats

7.1 ChiSqTest

The Chi-square test computes the statistics of independence of variables in a contingency table, e.g., the p-value and the DOF (degrees of freedom) for each input feature. The contingency table is constructed from the observed categorical values.
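
Concretely, for each feature a contingency table of (feature value, label) counts is built, and the standard chi-square statistic

    χ² = Σ_k (O_k − E_k)² / E_k

is computed, where O_k are the observed cell counts and E_k the counts expected if the feature and the label were independent; the p-value then follows from the χ² distribution with the reported degrees of freedom.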

Input Columns

Name | Type | Default | Description
featuresCol | Vector | “features” | Feature vector.
labelCol | Number | “label” | Label of the features.

Output Columns

If the output result is not flattened, the output columns are as follows.

Column name | Type | Description
“pValues” | Vector | Probability of obtaining a test statistic result at least as extreme as the one that was actually observed, assuming that the null hypothesis is true.
“degreesOfFreedom” | Int Array | Degrees of freedom of the hypothesis test.
“statistics” | Vector | Test statistic.

If the output result is flattened, the output columns are as follows.

Column name | Type | Description
“featureIndex” | Int | Index of the feature in the input vectors.
“pValue” | Double | Probability of obtaining a test statistic result at least as extreme as the one that was actually observed, assuming that the null hypothesis is true.
“degreeOfFreedom” | Int | Degree of freedom of the hypothesis test.
“statistic” | Double | Test statistic.

Parameters

Name | Default | Type | Required | Description
labelCol | “label” | String | no | Label column name.
featuresCol | “features” | String | no | Features column name.
flatten | false | Boolean | no | If false, the returned table contains only a single row; otherwise, one row per feature.

Code Example

import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.stats.chisqtest.ChiSqTest;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that creates a ChiSqTest instance and uses it for statistics. */
public class ChiSqTestExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        Table inputTable =
                tEnv.fromDataStream(
                                env.fromElements(
                                        Row.of(0., Vectors.dense(5, 1.)),
                                        Row.of(2., Vectors.dense(6, 2.)),
                                        Row.of(1., Vectors.dense(7, 2.)),
                                        Row.of(1., Vectors.dense(5, 4.)),
                                        Row.of(0., Vectors.dense(5, 1.)),
                                        Row.of(2., Vectors.dense(6, 2.)),
                                        Row.of(1., Vectors.dense(7, 2.)),
                                        Row.of(1., Vectors.dense(5, 4.)),
                                        Row.of(2., Vectors.dense(5, 1.)),
                                        Row.of(0., Vectors.dense(5, 2.)),
                                        Row.of(0., Vectors.dense(5, 2.)),
                                        Row.of(1., Vectors.dense(9, 4.)),
                                        Row.of(1., Vectors.dense(9, 3.))))
                        .as("label", "features");

        // Creates a ChiSqTest object and initializes its parameters.
        ChiSqTest chiSqTest =
                new ChiSqTest().setFlatten(true).setFeaturesCol("features").setLabelCol("label");

        // Uses the ChiSqTest object for statistics.
        Table outputTable = chiSqTest.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            System.out.printf(
                    "Feature Index: %s\tP Value: %s\tDegree of Freedom: %s\tStatistics: %s\n",
                    row.getField("featureIndex"),
                    row.getField("pValue"),
                    row.getField("degreeOfFreedom"),
                    row.getField("statistic"));
        }
    }
}

8 Functions

Flink ML provides users with some built-in table functions for data transformations. This page gives a brief overview of them.

8.1 vectorToArray

This function converts a column of Flink ML sparse/dense vectors into a column of double arrays.

Code Example

import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.List;

import static org.apache.flink.ml.Functions.vectorToArray;
import static org.apache.flink.table.api.Expressions.$;

/** Simple program that converts a column of dense/sparse vectors into a column of double arrays. */
public class VectorToArrayExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input vector data.
        List<Vector> vectors =
                Arrays.asList(
                        Vectors.dense(0.0, 0.0),
                        Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
        Table inputTable =
                tEnv.fromDataStream(env.fromCollection(vectors, VectorTypeInfo.INSTANCE))
                        .as("vector");

        // Converts each vector to a double array.
        Table outputTable = inputTable.select($("vector"), vectorToArray($("vector")).as("array"));

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            Vector vector = row.getFieldAs("vector");
            Double[] doubleArray = row.getFieldAs("array");
            System.out.printf(
                    "Input vector: %s\tOutput double array: %s\n",
                    vector, Arrays.toString(doubleArray));
        }
    }
}

8.2 arrayToVector

This function converts a column of arrays of numeric type into a column of DenseVector instances.

Code Example

import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.List;

import static org.apache.flink.ml.Functions.arrayToVector;
import static org.apache.flink.table.api.Expressions.$;

/** Simple program that converts a column of double arrays into a column of dense vectors. */
public class ArrayToVectorExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input double array data.
        List<double[]> doubleArrays =
                Arrays.asList(new double[] {0.0, 0.0}, new double[] {0.0, 1.0});
        Table inputTable = tEnv.fromDataStream(env.fromCollection(doubleArrays)).as("array");

        // Converts each double array to a dense vector.
        Table outputTable = inputTable.select($("array"), arrayToVector($("array")).as("vector"));

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            Double[] doubleArray = row.getFieldAs("array");
            Vector vector = row.getFieldAs("vector");
            System.out.printf(
                    "Input double array: %s\tOutput vector: %s\n",
                    Arrays.toString(doubleArray), vector);
        }
    }
}