FlinkML 监督学习的数据加载以及 SVM 算法示例(Java 实现)

一、监督学习类问题的数据加载

对于监督学习类问题,FlinkML 通常使用 LabeledVector 类来表示 (label, features) 实例。
以官方文档中使用的数据集为例,数据格式如下:
(原图缺失)数据为逗号分隔的文本,每行 4 个数值:前 3 列作为特征,第 4 列作为标签(下面的代码即按此格式解析)。
代码:

package cn.xsy.flink.ml;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.ml.common.LabeledVector;
import org.apache.flink.ml.math.DenseVector;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class LoadDataDemo {

    public static void main(String[] args) throws Exception {

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        //读文件
        DataSource<String> StringDataSource = env.readTextFile("D:\\data\\flinkML\\Haberman\\haberman.data");
        //将<String> 转换为<Double, DenseVector>
        DataSet<LabeledVector> survivalLV = StringDataSource.map(new MapFunction<String, LabeledVector>() {
            public LabeledVector map(String value) throws Exception {
                String[] split = value.split(",");
                Stream<Double> doubleStream = Stream.of(split).map(s -> Double.parseDouble(s));
                List<Double> collect = doubleStream.collect(Collectors.toList());
                double[] doubles = ArrayUtils.toPrimitive(collect.subList(0, 3).toArray(new Double[3]));
                return new LabeledVector(collect.get(3), new DenseVector(doubles));
            }
        });
        survivalLV.print();
    }
}

输出数据格式:
(原图缺失)每行打印一个 LabeledVector,包含标签以及由前 3 列组成的 DenseVector 特征向量。

二、SVM算法例子

以官方文档中使用的astroparticle二元分类数据集为例,格式如下:
(原图缺失)数据为 LibSVM 格式,每行形如 `<label> <index1>:<value1> <index2>:<value2> ...`,代码中通过 MLUtils.readLibSVM 读取。
代码实现:

package cn.xsy.flink.ml;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.scala.DataSet;
import org.apache.flink.api.scala.ExecutionEnvironment;
import org.apache.flink.ml.MLUtils;
import org.apache.flink.ml.classification.SVM;
import org.apache.flink.ml.common.LabeledVector;
import org.apache.flink.ml.common.ParameterMap;
import org.apache.flink.ml.math.Vector;
import org.apache.flink.ml.pipeline.EvaluateDataSetOperation;
import org.apache.flink.ml.pipeline.FitOperation;
import org.apache.flink.ml.pipeline.PredictOperation;
import org.apache.flink.ml.pipeline.Predictor$;
import scala.Tuple2;
import scala.reflect.ClassTag$;

import static org.apache.flink.ml.classification.SVM.fitSVM;

/**
 * Trains a FlinkML {@link SVM} on the astroparticle binary-classification data set (LibSVM
 * format) and prints (true label, predicted label) pairs for the test split.
 *
 * <p>Usage: {@code SVMDemo [trainPath [testPath]]}. Missing arguments fall back to
 * {@link #DEFAULT_TRAIN_PATH} and {@link #DEFAULT_TEST_PATH}.
 *
 * <p>FlinkML is written in Scala. This class therefore uses the Scala {@code DataSet} API and
 * passes the Scala implicits ({@code fitSVM}, {@code predictVectors}, the
 * {@code TypeInformation}/{@code ClassTag} evidence) explicitly.
 */
public class SVMDemo {

    /** Training file (LibSVM format) used when no first argument is given. */
    private static final String DEFAULT_TRAIN_PATH = "D:\\data\\flinkML\\astroparticle\\svmguide1";

    /** Test file (LibSVM format) used when no second argument is given. */
    private static final String DEFAULT_TEST_PATH = "D:\\data\\flinkML\\astroparticle\\svmguide1.t";

    // SVM hyper-parameters (same values as the FlinkML quickstart example).
    private static final int ITERATIONS = 100;
    private static final double REGULARIZATION = 0.001;
    private static final double STEPSIZE = 0.1;
    private static final long SEED = 42L;

    // The raw FitOperation / PredictOperation / EvaluateDataSetOperation casts are needed because
    // the generic signatures of the Scala implicits do not line up with Java type inference.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static void main(String[] args) {
        String trainPath = args.length > 0 ? args[0] : DEFAULT_TRAIN_PATH;
        String testPath = args.length > 1 ? args[1] : DEFAULT_TEST_PATH;

        ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<LabeledVector> astroTrainLibSVM = MLUtils.readLibSVM(environment, trainPath);
        DataSet<LabeledVector> astroTestLibSVM = MLUtils.readLibSVM(environment, testPath);

        // Map labels to -1.0 / +1.0 for the SVM.
        DataSet<LabeledVector> astroTrain = normalizeLabels(astroTrainLibSVM);
        // The evaluate step below is given (features, label) tuples rather than LabeledVectors.
        DataSet<Tuple2<Vector, Double>> astroTest = normalizeLabels(astroTestLibSVM)
                .map(new ToFeatureLabelPair(),
                        TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>() {}),
                        ClassTag$.MODULE$.apply(Tuple2.class));

        // SVM classifier
        SVM svm = new SVM();
        svm.setBlocks(environment.getParallelism())
                .setIterations(ITERATIONS)
                .setRegularization(REGULARIZATION)
                .setStepsize(STEPSIZE)
                .setSeed(SEED);
        // Train; fitSVM() is the implicit FitOperation that Scala callers get automatically.
        svm.fit(astroTrain, ParameterMap.Empty(), (FitOperation) fitSVM());

        EvaluateDataSetOperation evaluateDataSetOperation = Predictor$.MODULE$.defaultEvaluateDataSetOperation(
                (PredictOperation) SVM.predictVectors(),
                TypeInformation.of(new TypeHint<Tuple2<Vector, Double>>() {}),
                TypeInformation.of(Double.class));
        // Predict on the test set; each element is a (true value, prediction) pair.
        DataSet evaluate = svm.evaluate(astroTest, ParameterMap.Empty(), evaluateDataSetOperation);

        evaluate.print();
    }

    /**
     * Applies {@link Normalizer} to every element, mapping labels to -1.0 / +1.0.
     *
     * @param data labeled vectors as read from a LibSVM file
     * @return the same vectors with normalized labels
     */
    private static DataSet<LabeledVector> normalizeLabels(DataSet<LabeledVector> data) {
        return data.map(new Normalizer(),
                TypeInformation.of(LabeledVector.class),
                ClassTag$.MODULE$.apply(LabeledVector.class));
    }

    /** Maps a label greater than 0.0 to 1.0 and any other label to -1.0; the vector is kept as-is. */
    public static final class Normalizer implements MapFunction<LabeledVector, LabeledVector> {

        @Override
        public LabeledVector map(LabeledVector value) throws Exception {
            return new LabeledVector(value.label() > 0.0 ? 1.0 : -1.0, value.vector());
        }
    }

    /** Converts a {@link LabeledVector} into a Scala (features, label) tuple. */
    public static final class ToFeatureLabelPair
            implements MapFunction<LabeledVector, Tuple2<Vector, Double>> {

        @Override
        public Tuple2<Vector, Double> map(LabeledVector value) {
            return new Tuple2<>(value.vector(), value.label());
        }
    }
}

输出结果:
(原图缺失)每行打印一个 (真实标签, 预测标签) 二元组。

官方文档地址:

https://nightlies.apache.org/flink/flink-docs-release-1.8/dev/libs/ml/quickstart.html(原 ci.apache.org 地址已失效并迁移至此。注意:旧版 FlinkML 库已在 Flink 1.9 中移除,本文示例仅适用于 Flink 1.8 及更早版本。)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值