二 Spark机器学习MLlib: LogisticRegression

一 MLlib简单介绍

MLllib目前分为两个代码包:

spark.mllib 包含基于RDD的原始算法API。
spark.ml 则提供了基于DataFrames 高层次的API,可以用来构建机器学习管道。

本文用基于DataFrame的API,DataFrame结构与MySQL表基本一致,处理数据比较方便。

基于DataFrame的API,包名为:org.apache.spark.ml.*;
数据对象引用地址为:org.apache.spark.sql.*;

基于JavaRdd的API,包名为: org.apache.spark.mllib.*;

MLlib指南

二 LogisticRegression 模型

LogisticRegression的损失函数为:

L(w;x,y):=ln(1+exp(ywTx))

预测函数为: f(z)=11+ez;z=wTx

f(z) 大于0.5决策函数取1,否则取0。

求解模型参数有两种方法,一种梯度下降法,另一种是 L-BFGS.

梯度下降法详细情况参看博客

三 Spark计算

  • 引用包
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
  • 构造训练数据和测试数据。
//构造训练数据。
 List<Row> dataTraining = Arrays.asList(
        RowFactory.create(1.0, Vectors.dense(0.0, 1.1, 0.1)),
        RowFactory.create(0.0, Vectors.dense(2.0, 1.0, -1.0)),
        RowFactory.create(0.0, Vectors.dense(2.0, 1.3, 1.0)),
        RowFactory.create(1.0, Vectors.dense(0.0, 1.2, -0.5))
    );
 StructType schema = new StructType(
        new StructField[] { 
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty()) 
        }
    );
 Dataset<Row> training = spark.createDataFrame(dataTraining, schema);
 //测试数据
 List<Row> dataTest = Arrays.asList(
        RowFactory.create(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
        RowFactory.create(0.0, Vectors.dense(3.0, 2.0, -0.1)),
        RowFactory.create(1.0, Vectors.dense(0.0, 2.2, -1.5))
    );
Dataset<Row> test = spark.createDataFrame(dataTest, schema);
  • 训练模型,测试模型
//新建模型
LogisticRegression lr = new LogisticRegression();
//设置参数,迭代10次,正则化系数0.01
lr.setMaxIter(10).setRegParam(0.01);
//训练模型
LogisticRegressionModel model1 = lr.fit(training);
//决策
Dataset<Row> results = model1.transform(test);
//查看模型参数:
System.out.println(
    "Model was fit using parameters: " + model1.parent().extractParamMap()
    );
//查看结果
results
    .collectAsList()
    .forEach(
        row->System.out.println(
            "(" + row.get(0) + ", " + row.get(1) + ") -> prediction=" + row.get(3)
        )
    );
  • 运行结果
Model was fit using parameters: {
    logreg_10073fbf67d3-aggregationDepth: 2,
    logreg_10073fbf67d3-elasticNetParam: 0.0,
    logreg_10073fbf67d3-family: auto,
    logreg_10073fbf67d3-featuresCol: features,
    logreg_10073fbf67d3-fitIntercept: true,
    logreg_10073fbf67d3-labelCol: label,
    logreg_10073fbf67d3-maxIter: 10,
    logreg_10073fbf67d3-predictionCol: prediction,
    logreg_10073fbf67d3-probabilityCol: probability,
    logreg_10073fbf67d3-rawPredictionCol: rawPrediction,
    logreg_10073fbf67d3-regParam: 0.01,
    logreg_10073fbf67d3-standardization: true,
    logreg_10073fbf67d3-threshold: 0.5,
    logreg_10073fbf67d3-tol: 1.0E-6
}
17/09/19 23:29:15 INFO CodeGenerator: Code generated in 68.426465 ms
17/09/19 23:29:15 INFO CodeGenerator: Code generated in 35.931395 ms
//注 prediction 数组表示决策结果属于{0,1}的概率。
(1.0, [-1.0,1.5,1.3]) -> prediction=[0.0013759947069214296,0.9986240052930786]
(0.0, [3.0,2.0,-0.1]) -> prediction=[0.9816604009374171,0.018339599062582968]
(1.0, [0.0,2.2,-1.5]) -> prediction=[0.0016981475578358401,0.9983018524421641]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值