一、监督学习类问题的数据加载
对于监督学习类问题,FlinkML 通常使用 LabeledVector 类来表示 (label, features) 实例。
以官方文档中使用的数据集为例,数据格式如下:
代码:
package cn.xsy.flink.ml;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.ml.common.LabeledVector;
import org.apache.flink.ml.math.DenseVector;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class LoadDataDemo {
public static void main(String[] args) throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
//读文件
DataSource<String> StringDataSource = env.readTextFile("D:\\data\\flinkML\\Haberman\\haberman.data");
//将<String> 转换为<Double, DenseVector>
DataSet<LabeledVector> survivalLV = StringDataSource.map(new MapFunction<String, LabeledVector>() {
public LabeledVector map(String value) throws Exception {
String[] split = value.split(",");
Stream<Double> doubleStream = Stream.of(split).map(s -> Double.parseDouble(s));
List<Double> collect = doubleStream.collect(Collectors.toList());
double[] doubles = ArrayUtils.toPrimitive(collect.subList(0, 3).toArray(new Double[3]));
return new LabeledVector(collect.get(3), new DenseVector(doubles));
}
});
survivalLV.print();
}
}
输出数据格式:
二、SVM算法例子
以官方文档中使用的astroparticle二元分类数据集为例,格式如下:
代码实现:
package cn.xsy.flink.ml;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.scala.DataSet;
import org.apache.flink.api.scala.ExecutionEnvironment;
import org.apache.flink.ml.MLUtils;
import org.apache.flink.ml.classification.SVM;
import org.apache.flink.ml.common.LabeledVector;
import org.apache.flink.ml.common.ParameterMap;
import org.apache.flink.ml.math.Vector;
import org.apache.flink.ml.pipeline.EvaluateDataSetOperation;
import org.apache.flink.ml.pipeline.FitOperation;
import org.apache.flink.ml.pipeline.PredictOperation;
import org.apache.flink.ml.pipeline.Predictor$;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import static org.apache.flink.ml.classification.SVM.fitSVM;
/**
 * SVM binary-classification example on the astroparticle (svmguide1) data set,
 * driving FlinkML's Scala API from Java.
 *
 * <p>NOTE(review): this class imports the Scala {@code ExecutionEnvironment}
 * and {@code DataSet}, so every {@code map} call must supply an explicit
 * {@code TypeInformation} plus a Scala {@code ClassTag}, and the pipeline
 * type classes ({@code FitOperation}, {@code PredictOperation}) have to be
 * passed explicitly via raw-typed casts — hence the unchecked warnings.
 */
public class SVMDemo {
public static void main(String[] args) {
ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();
// Load train/test sets from LibSVM-format files as LabeledVector instances.
DataSet<LabeledVector> astroTrainLibSVM = MLUtils.readLibSVM(environment, "D:\\data\\flinkML\\astroparticle\\svmguide1");
DataSet<LabeledVector> astroTestLibSVM = MLUtils.readLibSVM(environment, "D:\\data\\flinkML\\astroparticle\\svmguide1.t");
DataSet<LabeledVector> astroTrain = astroTrainLibSVM
// Normalize the label to -1/+1 (the signed binary encoding the SVM expects).
.map(new Normalizer(), TypeInformation.of(LabeledVector.class), ClassTag$.MODULE$.apply(LabeledVector.class));
DataSet<Tuple2<Vector, Double>> astroTest = astroTestLibSVM
// Normalize the label to -1/+1.
.map(new Normalizer(), TypeInformation.of(LabeledVector.class), ClassTag$.MODULE$.apply(LabeledVector.class))
// Convert each (label, vector) LabeledVector into a (vector, label) tuple,
// the input shape evaluate() expects.
.map(new MapFunction<LabeledVector, Tuple2<Vector, Double>>() {
@Override
public Tuple2<Vector, Double> map(LabeledVector value) throws Exception {
return new Tuple2<>(value.vector(), value.label());
}
}, TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>() {}), ClassTag$.MODULE$.apply(Tuple2.class));
// astroTrain.print();
// astroTest.print();
// Configure the SVM classifier.
SVM svm = new SVM();
svm.setBlocks(environment.getParallelism())
.setIterations(100)
.setRegularization(0.001)
.setStepsize(0.1)
.setSeed(42);
// Train: fit() requires the Scala FitOperation type class, obtained from the
// statically imported fitSVM() and cast because Java cannot infer it.
svm.fit(astroTrain, ParameterMap.Empty(), (FitOperation)fitSVM());
// Build the evaluation operation from the default vector PredictOperation.
EvaluateDataSetOperation evaluateDataSetOperation = Predictor$.MODULE$.defaultEvaluateDataSetOperation(
(PredictOperation) SVM.predictVectors(),
TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>() {}),
TypeInformation.of(Double.class));
// Predict on the test set; each result tuple is (true label, predicted label).
DataSet evaluate = svm.evaluate(astroTest, ParameterMap.Empty(), evaluateDataSetOperation);
evaluate.print();
}
/** Maps the raw label onto the {-1.0, +1.0} encoding required by the SVM. */
public static final class Normalizer implements MapFunction<LabeledVector, LabeledVector> {
@Override
public LabeledVector map(LabeledVector value) throws Exception {
// Any positive label becomes +1.0; everything else becomes -1.0.
return new LabeledVector(value.label() > 0.0 ? 1.0 : -1.0, value.vector());
}
}
}
输出结果:
官方文档地址:
https://ci.apache.org/projects/flink/flink-docs-release-1.8/dev/libs/ml/quickstart.html