FlinkML 多元线性回归例子及训练参数提取

本文使用FlinkML机器学习库中的多元线性回归算法实现一个小demo。

一、构造训练集LabeledVector

按照文档中的步骤,我们首先构造一个LabeledVector
LabeledVector表示(label, features)
        labe:分类问题中的类,也可以是回归问题中的因变量
        features:样本特征
代码如下:

val env = ExecutionEnvironment.getExecutionEnvironment
    val dataset:DataSet[(Double,Double)] = env.fromElements((1,2),(2,4),(3,6))
    val datasetLV:DataSet[LabeledVector] = dataset.map{ x=>
      LabeledVector(x._2,DenseVector(x._1))
}

二、创建学习器

val mlr = MultipleLinearRegression()
.setIterations(10)
.setStepsize(0.5)
.setConvergenceThreshold(0.001)

三、使用训练集进行线性拟合

mlr.fit(datasetLV)

官方文档中可以使用Predict函数进行评估,但是Predict函数在[Flink-2116]commit中已经被删除,文档没有及时更新,目前需要我们自己编写函数进行评估。

四、提取训练完成后的参数

参数在mlr.weightsOption中,可以使用以下代码进行提取:

val weights = mlr.weightsOption match {
      case Some(weights) => weights.collect()
      case None => throw new Exception("Could not calculate the weights.")
    }
//    println(weights.toString())
    val vector =  weights.iterator.next()
    val theat0 = vector.weights.apply(0)
    val theat1 = weights.iterator.next().intercept

官网文档地址:

https://ci.apache.org/projects/flink/flink-docs-release-1.8/dev/libs/ml/multiple_linear_regression.html

注:
调用predict函数会报错:
There is no PredictOperation defined for org.apache.flink.ml.regression.MultipleLinearRegression which takes a DataSet[org.apache.flink.ml.common.LabeledVector] as input.

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值