Quick Start with Spark MLlib, Part 14: Logistic Regression (Binary Logistic Regression)

(1) Description

Logistic regression

L-BFGS supports both binomial and multinomial logistic regression, while SGD supports only binomial logistic regression. The L-BFGS version does not support L1 regularization, whereas the SGD version does. When L1 regularization is not required, the L-BFGS version is recommended: by using a quasi-Newton approximation of the Hessian matrix, it converges faster and more accurately.
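Regardless of which solver trains the weights, the probability the model reports (once the decision threshold is cleared, as in the code below) is the logistic (sigmoid) function applied to the linear margin w·x + b. A minimal plain-Java sketch of that function; the class name `Sigmoid` is illustrative, not part of MLlib:

```java
public class Sigmoid {
    // Logistic function: maps a linear margin in (-inf, +inf) to a probability in (0, 1).
    static double sigmoid(double margin) {
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    public static void main(String[] args) {
        System.out.println(sigmoid(0.0));   // margin 0 gives probability 0.5
        System.out.println(sigmoid(5.0));   // large positive margin, close to 1
        System.out.println(sigmoid(-5.0));  // large negative margin, close to 0
    }
}
```

This is why the raw scores in the test output below cluster near 0 or 1: confident margins saturate the sigmoid.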

(2) Sample Data

1 159:124 160:253 161:255 162:63 186:96 187:244 188:251 189:253 190:62 214:127 215:251 216:251 217:253 218:62 241:68 242:236 243:251 244:211 245:31 246:8 268:60 269:228 270:251 271:251 272:94 296:155 297:253 298:253 299:189 323:20 324:253 325:251 326:235 327:66 350:32 351:205 352:253 353:251 354:126 378:104 379:251 380:253 381:184 382:15 405:80 406:240 407:251 408:193 409:23 432:32 433:253 434:253 435:253 436:159 460:151 461:251 462:251 463:251 464:39 487:48 488:221 489:251 490:251 491:172 515:234 516:251 517:251 518:196 519:12 543:253 544:251 545:251 546:89 570:159 571:255 572:253 573:253 574:31 597:48 598:228 599:253 600:247 601:140 602:8 625:64 626:251 627:253 628:220 653:64 654:251 655:253 656:220 681:24 682:193 683:253 684:220
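Each line is in LibSVM format: a label followed by sparse `index:value` feature pairs (here, 1-based pixel indices of a handwritten-digit image with their nonzero gray values). `MLUtils.loadLibSVMFile` in the code below does the parsing for you; as a sketch of what the format means, here is a minimal plain-Java parser (the `LibSVMLine` class is illustrative, not part of MLlib):

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class LibSVMLine {
    // The label is the first whitespace-separated token of the line.
    static double parseLabel(String line) {
        return Double.parseDouble(line.trim().split("\\s+")[0]);
    }

    // The remaining tokens are sparse "index:value" feature pairs.
    // File indices are 1-based; MLlib shifts them to 0-based internally.
    static Map<Integer, Double> parseFeatures(String line) {
        String[] parts = line.trim().split("\\s+");
        Map<Integer, Double> features = new LinkedHashMap<>();
        for (int i = 1; i < parts.length; i++) {
            String[] kv = parts[i].split(":");
            features.put(Integer.parseInt(kv[0]) - 1, Double.parseDouble(kv[1]));
        }
        return features;
    }

    public static void main(String[] args) {
        String line = "1 159:124 160:253 161:255";
        System.out.println(parseLabel(line));           // 1.0
        System.out.println(parseFeatures(line).size()); // 3
    }
}
```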

(3) Test Code

import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;

public static void main(String[] args) {
  SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example");
  SparkContext sc = new SparkContext(conf);

  // $example on$
  String path = "sample_binary_classification_data.txt";
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

  // Split the initial RDD into two: 60% training data, 40% test data.
  JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
  JavaRDD<LabeledPoint> training = splits[0].cache();
  JavaRDD<LabeledPoint> test = splits[1];

  // Run the training algorithm to build the model.
  final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
      .setNumClasses(2)
      .run(training.rdd());

  // Clear the prediction threshold so the model returns probabilities.
  model.clearThreshold();

  // Compute raw scores on the test set.
  JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
      new Function<LabeledPoint, Tuple2<Object, Object>>() {
        public Tuple2<Object, Object> call(LabeledPoint p) {
          Double prediction = model.predict(p.features());
          return new Tuple2<Object, Object>(prediction, p.label());
        }
      }
  );
  System.out.println("------------------------------------->" + predictionAndLabels.take(10));

  // Get evaluation metrics.
  BinaryClassificationMetrics metrics =
      new BinaryClassificationMetrics(predictionAndLabels.rdd());

  // Precision by threshold
  JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
  System.out.println("Precision by threshold: " + precision.collect());

  // Recall by threshold
  JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
  System.out.println("Recall by threshold: " + recall.collect());

  // F-score by threshold
  JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
  System.out.println("F1 Score by threshold: " + f1Score.collect());

  JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
  System.out.println("F2 Score by threshold: " + f2Score.collect());

  // Precision-recall curve
  JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
  System.out.println("Precision-recall curve: " + prc.collect());

  // Thresholds
  JavaRDD<Double> thresholds = precision.map(
      new Function<Tuple2<Object, Object>, Double>() {
        public Double call(Tuple2<Object, Object> t) {
          return Double.parseDouble(t._1().toString());
        }
      }
  );

  // ROC curve
  JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
  System.out.println("ROC curve: " + roc.collect());

  // AUPRC
  System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());

  // AUROC
  System.out.println("Area under ROC = " + metrics.areaUnderROC());

  // Save and load the model.
  // model.save(sc, "target/tmp/LogisticRegressionModel");
  // LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc,
  //     "target/tmp/LogisticRegressionModel");
  // $example off$

  sc.stop();
}

(4) Test Results

spark-submit --class org.apache.spark.examples.mllib.JavaBinaryClassificationMetricsExample --master yarn --deploy-mode cluster --driver-memory 1G --executor-memory 1G --executor-cores 3  spark.jar

[(0.9999999980030561,1.0), (0.9999993163872242,1.0), (1.3972031658836056E-7,0.0), (0.9999999921111843,1.0), (9.446085185778696E-11,0.0), (6.450291392815944E-10,0.0), (0.9999960784492523,1.0), (0.9999998349761544,1.0), (0.9999999993914117,1.0), (6.483613249378359E-10,0.0)]
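Each printed pair is (probability, label). Since the threshold was cleared, recovering hard 0/1 predictions means applying a threshold yourself. A minimal plain-Java sketch computing accuracy at threshold 0.5 over pairs like the ones above; the `accuracyAt` helper is illustrative, not an MLlib API:

```java
public class ThresholdCheck {
    // Fraction of (score, label) pairs where thresholding the score matches the label.
    static double accuracyAt(double threshold, double[][] scoreAndLabel) {
        int correct = 0;
        for (double[] p : scoreAndLabel) {
            double predicted = p[0] >= threshold ? 1.0 : 0.0;
            if (predicted == p[1]) correct++;
        }
        return (double) correct / scoreAndLabel.length;
    }

    public static void main(String[] args) {
        // First four (probability, label) pairs from the output above.
        double[][] pairs = {
            {0.9999999980030561, 1.0},
            {0.9999993163872242, 1.0},
            {1.3972031658836056E-7, 0.0},
            {0.9999999921111843, 1.0},
        };
        System.out.println(accuracyAt(0.5, pairs)); // 1.0 on these pairs
    }
}
```

Calling `model.setThreshold(0.5)` instead of `clearThreshold()` would make `predict` return the 0/1 labels directly, but then `BinaryClassificationMetrics` could no longer sweep thresholds.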
