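(1) Description
Logistic regression
The L-BFGS version supports both binary and multinomial logistic regression, whereas the SGD version supports only binary logistic regression. The L-BFGS version does not support L1 regularization; the SGD version does. When L1 regularization is not required, the L-BFGS version is recommended: it uses a quasi-Newton method to approximate the Hessian matrix, so it converges faster and more accurately.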
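To make the binary case concrete, here is a minimal plain-Java sketch (no Spark dependency) of how a trained binary logistic regression model scores one example: it applies the logistic function to the weighted sum of the features and, if a threshold is set, turns the probability into a 0/1 label. The class name, weights, intercept and feature values are hypothetical and chosen only for illustration; they are not taken from the model trained in the test code.

```java
/** Illustrative sketch only: scoring with a binary logistic regression model. */
final class BinaryLogisticSketch {

    /** Logistic (sigmoid) function: maps any real-valued margin into (0, 1). */
    static double sigmoid(double margin) {
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /** Probability of class 1 for one dense feature vector. */
    static double probability(double[] weights, double intercept, double[] features) {
        if (weights.length != features.length) {
            throw new IllegalArgumentException("weights and features must have the same length");
        }
        double margin = intercept;
        for (int i = 0; i < weights.length; i++) {
            margin += weights[i] * features[i];
        }
        return sigmoid(margin);
    }

    /** Hard label: 1.0 if the probability exceeds the threshold, otherwise 0.0. */
    static double predict(double probability, double threshold) {
        return probability > threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] weights = {0.8, -0.4};  // hypothetical weights
        double intercept = 0.1;          // hypothetical intercept
        double[] features = {2.0, 1.0};  // hypothetical example

        double p = probability(weights, intercept, features);
        System.out.println("P(y = 1 | x) = " + p);
        System.out.println("label at threshold 0.5 = " + predict(p, 0.5));
    }
}
```

Calling model.clearThreshold(), as the test code in section (3) does, corresponds to skipping the last step: predict then returns the probability itself rather than a 0/1 label.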
(2) Sample data
1 159:124 160:253 161:255 162:63 186:96 187:244 188:251 189:253 190:62 214:127 215:251 216:251 217:253 218:62 241:68 242:236 243:251 244:211 245:31 246:8 268:60 269:228 270:251 271:251 272:94 296:155 297:253 298:253 299:189 323:20 324:253 325:251 326:235 327:66 350:32 351:205 352:253 353:251 354:126 378:104 379:251 380:253 381:184 382:15 405:80 406:240 407:251 408:193 409:23 432:32 433:253 434:253 435:253 436:159 460:151 461:251 462:251 463:251 464:39 487:48 488:221 489:251 490:251 491:172 515:234 516:251 517:251 518:196 519:12 543:253 544:251 545:251 546:89 570:159 571:255 572:253 573:253 574:31 597:48 598:228 599:253 600:247 601:140 602:8 625:64 626:251 627:253 628:220 653:64 654:251 655:253 656:220 681:24 682:193 683:253 684:220
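The sample line above is in LIBSVM text format: a label followed by space-separated index:value pairs, where indices are one-based and ascending and only nonzero features are listed. As a rough illustration of what a loader has to do with such a line, here is a minimal plain-Java parsing sketch. The class and method names are hypothetical, and this is not the MLUtils.loadLibSVMFile implementation, although it mirrors that loader's documented conversion to zero-based indices.

```java
import java.util.Arrays;

/** Illustrative sketch only: one parsed LIBSVM line as a label plus a sparse vector. */
final class LibSvmLineSketch {
    final double label;
    final int[] indices;    // zero-based feature indices
    final double[] values;  // feature values, aligned with indices

    LibSvmLineSketch(double label, int[] indices, double[] values) {
        this.label = label;
        this.indices = indices;
        this.values = values;
    }

    /** Parses "label i1:v1 i2:v2 ...". Blank lines and comments are not handled. */
    static LibSvmLineSketch parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);
        int n = tokens.length - 1;
        int[] indices = new int[n];
        double[] values = new double[n];
        for (int i = 0; i < n; i++) {
            String[] pair = tokens[i + 1].split(":");
            if (pair.length != 2) {
                throw new IllegalArgumentException("Malformed feature: " + tokens[i + 1]);
            }
            indices[i] = Integer.parseInt(pair[0]) - 1;  // one-based to zero-based
            values[i] = Double.parseDouble(pair[1]);
        }
        return new LibSvmLineSketch(label, indices, values);
    }

    public static void main(String[] args) {
        // First few features of the sample line shown above.
        LibSvmLineSketch point = parse("1 159:124 160:253 161:255 162:63");
        System.out.println("label   = " + point.label);
        System.out.println("indices = " + Arrays.toString(point.indices));
        System.out.println("values  = " + Arrays.toString(point.values));
    }
}
```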
(3) Test code
package org.apache.spark.examples.mllib;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;

public class JavaBinaryClassificationMetricsExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example");
    SparkContext sc = new SparkContext(conf);
    String path = "sample_binary_classification_data.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
    // Split the initial RDD into two: 60% training data, 40% testing data.
    JavaRDD<LabeledPoint>[] splits =
      data.randomSplit(new double[]{0.6, 0.4}, 11L);
    JavaRDD<LabeledPoint> training = splits[0].cache();
    JavaRDD<LabeledPoint> test = splits[1];
    // Run the training algorithm to build the model.
    final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
      .setNumClasses(2)
      .run(training.rdd());
    // Clear the prediction threshold so the model returns probabilities.
    model.clearThreshold();
    // Compute raw scores on the test set as (probability, true label) pairs.
    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
      new Function<LabeledPoint, Tuple2<Object, Object>>() {
        @Override
        public Tuple2<Object, Object> call(LabeledPoint p) {
          Double prediction = model.predict(p.features());
          return new Tuple2<Object, Object>(prediction, p.label());
        }
      }
    );
    System.out.println("------------------------------------->" + predictionAndLabels.take(10));
    // Get evaluation metrics.
    BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd());
    // Precision by threshold (collect() replaces toArray(), which newer Spark releases no longer provide)
    JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
    System.out.println("Precision by threshold: " + precision.collect());
    // Recall by threshold
    JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
    System.out.println("Recall by threshold: " + recall.collect());
    // F score by threshold (beta = 1 by default, then beta = 2)
    JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
    System.out.println("F1 Score by threshold: " + f1Score.collect());
    JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
    System.out.println("F2 Score by threshold: " + f2Score.collect());
    // Precision-recall curve
    JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
    System.out.println("Precision-recall curve: " + prc.collect());
    // Thresholds (computed but not printed)
    JavaRDD<Double> thresholds = precision.map(
      new Function<Tuple2<Object, Object>, Double>() {
        @Override
        public Double call(Tuple2<Object, Object> t) {
          return Double.parseDouble(t._1().toString());
        }
      }
    );
    // ROC curve
    JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
    System.out.println("ROC curve: " + roc.collect());
    // AUPRC
    System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());
    // AUROC
    System.out.println("Area under ROC = " + metrics.areaUnderROC());
    // Save and load model
    // model.save(sc, "target/tmp/LogisticRegressionModel");
    // LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc,
    //   "target/tmp/LogisticRegressionModel");
    sc.stop();
  }
}
(4) Test results
spark-submit --class org.apache.spark.examples.mllib.JavaBinaryClassificationMetricsExample --master yarn --deploy-mode cluster --driver-memory 1G --executor-memory 1G --executor-cores 3 spark.jar
[(0.9999999980030561,1.0), (0.9999993163872242,1.0), (1.3972031658836056E-7,0.0), (0.9999999921111843,1.0), (9.446085185778696E-11,0.0), (6.450291392815944E-10,0.0), (0.9999960784492523,1.0), (0.9999998349761544,1.0), (0.9999999993914117,1.0), (6.483613249378359E-10,0.0)]
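Each pair in the output above is (predicted probability, true label). To show how threshold-based metrics such as precision, recall and the F-beta score are defined, the following plain-Java sketch computes them at a single threshold over these ten pairs only. It is a hand-rolled illustration with hypothetical names, not the BinaryClassificationMetrics implementation. Treating a score at or above the threshold as positive, and the values returned when a denominator is zero, are assumptions of this sketch.

```java
/** Illustrative sketch only: precision, recall and F-beta at one threshold. */
final class ThresholdMetricsSketch {

    /** The ten (probability, label) pairs printed in the test results. */
    static final double[][] FIRST_TEN = {
        {0.9999999980030561, 1.0}, {0.9999993163872242, 1.0},
        {1.3972031658836056E-7, 0.0}, {0.9999999921111843, 1.0},
        {9.446085185778696E-11, 0.0}, {6.450291392815944E-10, 0.0},
        {0.9999960784492523, 1.0}, {0.9999998349761544, 1.0},
        {0.9999999993914117, 1.0}, {6.483613249378359E-10, 0.0}
    };

    /** Returns {precision, recall}, treating score >= threshold as a positive prediction. */
    static double[] precisionRecall(double[][] scoreAndLabel, double threshold) {
        int tp = 0, fp = 0, fn = 0;
        for (double[] pair : scoreAndLabel) {
            boolean predictedPositive = pair[0] >= threshold;
            boolean actualPositive = pair[1] == 1.0;
            if (predictedPositive && actualPositive) tp++;
            else if (predictedPositive) fp++;
            else if (actualPositive) fn++;
        }
        // Sketch conventions for empty denominators: precision 1.0, recall 0.0.
        double precision = (tp + fp == 0) ? 1.0 : (double) tp / (tp + fp);
        double recall = (tp + fn == 0) ? 0.0 : (double) tp / (tp + fn);
        return new double[]{precision, recall};
    }

    /** F-beta = (1 + beta^2) * P * R / (beta^2 * P + R); beta = 1 gives F1, beta = 2 gives F2. */
    static double fMeasure(double precision, double recall, double beta) {
        double b2 = beta * beta;
        double denominator = b2 * precision + recall;
        return denominator == 0.0 ? 0.0 : (1.0 + b2) * precision * recall / denominator;
    }

    public static void main(String[] args) {
        double[] pr = precisionRecall(FIRST_TEN, 0.5);
        System.out.println("precision = " + pr[0]);
        System.out.println("recall    = " + pr[1]);
        System.out.println("F1        = " + fMeasure(pr[0], pr[1], 1.0));
        System.out.println("F2        = " + fMeasure(pr[0], pr[1], 2.0));
    }
}
```

On these ten pairs, every probability above 0.5 carries label 1.0 and every probability below 0.5 carries label 0.0, so precision, recall, F1 and F2 at threshold 0.5 all come out as 1.0. Ten samples say little about the full test set; the metrics printed by the test code cover all of it.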