Spark MLib机器学习流水线Pipeline
DataFrame作为基本的数据抽象。
Transform:转化器,传入DataFrame转换成新的DataFrame。
Estimator:评估器,fit训练得到模型。
Pipeline:流水线,多步骤组合。
构建Pipeline:
// 构建一个机器学习工作流
// 在原始DataFrame上调用Pipeline.fit()方法,它具有原始文本文档和标签。
// Tokenizer.transform()方法将原始文本文档拆分为单词,向DataFrame添加一个带有单词的新列。
// HashingTF.transform()方法将字列转换为特征向量,向这些向量添加一个新列到DataFrame。
import org.apache.spark.sql.SparkSession
val spark=SparkSession.builder().master("local").appName("pipelines").getOrCreate()
import spark.implicits._//和SQLContext一样开启隐式转换
// 引入相关包和训练数据集
import org.apache.spark.ml.feature._
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.{Pipeline,PipelineModel}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row
val training = spark.createDataFrame(Seq(
(0L, "a b c d e spark", 1.0),
(1L, "b d", 0.0),
(2L, "spark f g h", 1.0),
(3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")
// pipelines包含tokenizer,hashingTF,lr三个步骤
// 将原始文本拆分成单词
val tokenizer=new Tokenizer().setInputCol("text").setOutputCol("words")
// 将单词转化成特征向量
val hashingTF=new HashingTF().setNumFeatures(1000).
setInputCol(tokenizer.getOutputCol).
setOutputCol("features")
// 创建一个机器学习模型并设置参数
val lr=new LogisticRegression().
setMaxIter(10).
setRegParam(0.01)
// 组织一个pipeline
val pipeline=new Pipeline().
setStages(Array(tokenizer,hashingTF,lr))
//通过piprline的fit方法传入训练数据创建model
val model=pipeline.fit(training)
// 测试数据
val test = spark.createDataFrame(Seq(
(4L, "spark i j k"),
(5L, "l m n"),
(6L, "spark a"),
(7L, "apache hadoop")
)).toDF("id", "text")
// 通过transform方法预测
model.transform(test).select("id", "text", "probability", "prediction").
collect().
foreach{ case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
println(s"($id, $text) --> prob=$prob, prediction=$prediction")
}
// 看到预测的概率值
// 由于训练数据集较少,如果有更多的测试数据进行学习,预测的准确率将会有显著提升
测试数据的预测结果:
(4, spark i j k) --> prob=[0.540643354485232,0.45935664551476796], prediction=0.0
(5, l m n) --> prob=[0.9334382627383527,0.06656173726164716], prediction=0.0
(6, spark a) --> prob=[0.1504143004807332,0.8495856995192668], prediction=1.0
(7, apache hadoop) --> prob=[0.9768636139518375,0.02313638604816238], prediction=0.0