概念
基于DataFrame API机器学习库Spark ML中提供Pipeline管道构建模型,方便实际项目中开发与部署。通俗地解释,就是一条流水线,它将多个处理步骤或组件串联起来,形成一个有序的工作流程。
特点
流水线式设计
Pipeline管道模型就像一个工厂中的流水线,原材料(数据)进入流水线的一端,经过一系列的加工处理(如数据清洗、转换、特征提取等),最终从流水线的另一端输出成品(处理后的数据或结果)。
提升效率
通过流水线式的设计,Pipeline管道模型可以显著提升数据处理的效率。在传统的数据处理过程中,可能需要手动编写大量的代码来执行每一步操作,并且需要手动管理数据在不同步骤之间的传递。而Pipeline管道模型将这些步骤封装在一起,自动管理数据的传递和处理,从而减少了重复劳动和出错的可能性。
降低复杂度
对于复杂的数据处理任务,Pipeline管道模型可以将任务拆分成多个简单的步骤,每个步骤只负责完成一部分工作。这种分而治之的策略可以降低任务的复杂度,使得每个步骤都更加容易理解和实现。
灵活性和可扩展性
Pipeline管道模型具有良好的灵活性和可扩展性。用户可以根据需要添加或删除步骤,也可以替换某个步骤的实现方式。这种灵活性使得Pipeline管道模型能够适应各种不同的数据处理需求。
应用领域
具体来说,Pipeline管道模型在以下场景中有着广泛的应用:
机器学习中的数据预处理
在机器学习任务中,数据预处理通常包括数据清洗、转换、特征提取等多个步骤。通过Pipeline管道模型,可以将这些步骤串联起来,形成一个完整的预处理流程,提高数据处理的效率和准确性。
Redis客户端的批处理
在Redis客户端中,使用Pipeline可以批量执行多个命令,减少网络交互的次数,从而提高性能。例如,当有多个命令需要被“及时的”提交,并且这些命令对响应结果没有互相依赖时,可以使用Pipeline进行批处理。
网络编程中的事件处理
在网络编程中,ChannelPipeline是一个处理或拦截channel的进站事件和出站事件的双向链表。通过增加或删除ChannelHandler,可以实现对不同业务逻辑的处理,提高网络编程的灵活性和可扩展性。
简单举个栗子:
机器学习中的数据预处理Pipeline
数据清洗:Pipeline的第一个步骤通常是数据清洗,它负责处理原始数据中的缺失值、异常值和重复数据等问题。例如,在一个预测房价的机器学习项目中,数据清洗可能包括删除含有缺失值的记录、填充缺失值(如使用均值、中位数等)、处理异常值(如设置阈值进行裁剪)等步骤。
特征工程:接下来是特征工程步骤,它涉及特征选择、特征转换和特征降维等技术。在这个例子中,特征工程可能包括选择对房价预测有重要影响的特征(如房屋面积、位置、房龄等)、进行特征缩放(如归一化、标准化)、创建新的特征(如房屋面积与房龄的比值)等。
模型训练和评估:经过特征工程后,数据被输入到选定的机器学习模型中进行训练和评估。训练过程中,可以使用交叉验证等技术来评估模型的性能,并通过调整超参数来优化模型。
自动化和集成:通过Pipeline,这些步骤被整合在一起,形成一个自动化的数据预处理和模型训练流程。这样,当新的数据到来时,只需要将数据输入到Pipeline中,就可以自动完成数据预处理、特征工程、模型训练和评估等步骤,大大提高了工作效率。
网络编程中的ChannelPipeline
事件处理:在网络编程中,ChannelPipeline是一个处理或拦截channel的进站事件和出站事件的双向链表。当一个网络事件发生时(如数据接收、连接建立等),它会在ChannelPipeline中传递和处理。
添加和删除Handler:用户可以在ChannelPipeline中添加或删除ChannelHandler来处理不同的事件和业务逻辑。例如,在一个TCP服务器中,可以添加一个Decoder Handler来解码接收到的字节数据为应用层消息,然后添加一个业务处理Handler来处理这些消息。
灵活性:由于ChannelPipeline的灵活性,用户可以轻松地修改事件处理流程,以适应不同的业务需求和场景。例如,可以添加新的Handler来处理新的业务逻辑,或者删除不再需要的Handler来简化流程。
Redis客户端的Pipeline批处理
批量操作:在Redis客户端中,Pipeline允许用户将多个命令打包成一个请求发送给Redis服务器,从而减少网络交互的次数,提高性能。例如,假设有一个需求需要同时设置10000个键值对,使用Pipeline可以将这10000个命令打包成一个请求发送,而不是发送10000个单独的请求。
性能提升:通过减少网络交互的次数,Pipeline可以显著提高Redis客户端的性能。在上面的例子中,使用Pipeline可以大大减少网络延迟和传输时间,从而加快键值对的设置速度。
管道Pipeline概念:将多个Transformer转换器和Estimators模型学习器按照以来顺序组工作流WorkFlow形式,方便数据集的特征转换和模型训练以及预测。
概念图
转变成
关键词
DataFrame:数据框,一种数据结构,来源于Spark SQL中,DataFrame = Dataset[Row],存储要训练和测试的数据集;
Transformer:转换器,一种算法Algorithm,必须实现transform方法。比如:模型Model就是一个转换器,将输入的数据集DataFrame转换为预测结果的数据集DataFrame;
Estimator:估计器或者模型学习器,将数据集DataFrame转换为一个Transformer,实现fit()方法,输入一个DataFrame并产生一个Model,即一个Transformer(转换器);
Pipeline:管道,
Parameter:参数,无论是转换器Transformer还是模型学习器Estimator都是一个算法,使用算法的时候必然会有参数。
一个经典的机器学习构建包括若干个过程:
以上四个步骤可以抽象为一个包括多个步骤的流水线式工作,从数据收集开始,至输出需要的最终结果。因此,对以上多个步骤进行抽象建模,简化为流水线式工作流程,则存在着可行性,对利用SparkMLlib进行机器学习的用户来说,流水线式机器学习比单个步骤独立建模更加高效、易用。
一个pipeline在结构上会包含以一个或多个Stage,每一个Stage都会完成一个认为,如数据集处理转化,模型训练,参数设置或数据预测等,这样的Stage在ML里按照处理问题类型的不同都有相应的定义和实现。两个主要的stage为Transformer和Estimator。
Transformer模型:
用来操作一个DataFrame数据并生成另一个DataFrame数据,比如svm模型,、一个特征提取工具,都可以抽象为一个Transformer。
Estimator算法:
用来做模型拟合,生成一个Transformer。
Pipeline官方案例
将【官方案例】决策树分类代码,改为使用Pipeline构建模型与预测,流程示意图如下:
代码实现
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
object PipelineClassification {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName(this.getClass.getSimpleName.stripSuffix("$"))
.master("local[4]")
.getOrCreate()
import org.apache.spark.sql.functions._
import spark.implicits._
// 1. 加载数据
val dataframe: DataFrame = spark.read
.format("libsvm")
.load("datas/mllib/sample_libsvm_data.txt")
//dataframe.printSchema()
//dataframe.show(10, truncate = false)
// 划分数据集:训练数据和测试数据
val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2))
// 2. 特征工程:特征提取、特征转换及特征选择
// 2.1. 将标签值label,转换为索引,从0开始,到 K-1
val labelIndexer: StringIndexerModel = new StringIndexer()
.setInputCol("label")
.setOutputCol("index_label")
.fit(dataframe)
val df1: DataFrame = labelIndexer.transform(dataframe)
// 2.2. 对类别特征数据进行特殊处理, 当每列的值的个数小于等于设置K,那么此列数据被当做类别特征,自动进行索引转换
val featureIndexer: VectorIndexerModel = new VectorIndexer()
.setInputCol("features")
.setOutputCol("index_features")
// TODO: 表示哪些特征列为类别特征列,并且将特征列的特征值进行索引化转换操作
.setMaxCategories(4) // 类别特征最大类别个数
.fit(df1)
val df2: DataFrame = featureIndexer.transform(df1)
// 3. 使用决策树算法构建分类模型
val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
.setLabelCol("index_label")
.setFeaturesCol("index_features")
// 设置决策树算法相关超参数
.setImpurity("gini") // 也可以是香农熵:entropy
.setMaxDepth(5)
.setMaxBins(32) // 此值必须大于等于类别特征类别个数
// TODO: 4. 构建Pipeline管道,设置Stage,每个Stage要么是算法(模型学习器Estimator),要么是模型(转换器Transformer)
val pipeline: Pipeline = new Pipeline()
// 设置Stage,依赖顺序
.setStages(
Array(labelIndexer, featureIndexer, dtc)
)
// 应用数据集,训练管道模型
val pipelineModel: PipelineModel = pipeline.fit(trainingDF)
// TODO: 获取决策树分类模型
val dtcModel: DecisionTreeClassificationModel = pipelineModel
.stages(2)
.asInstanceOf[DecisionTreeClassificationModel]
println(dtcModel.toDebugString)
// 4. 模型评估
val predictionDF: DataFrame = pipelineModel.transform(testingDF)
predictionDF.printSchema()
predictionDF
.select($"index_label", $"probability", $"prediction")
.show(20, truncate = false)
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("index_label")
.setPredictionCol("prediction")
.setMetricName("accuracy")
val accuracy: Double = evaluator.evaluate(predictionDF)
println(s"Accuracy = $accuracy")
spark.stop()
}
}
扩展补充部分
将刚刚学到的Pipeline模型写入上一篇USG用户购物性别标签中,完善代码:
import cn.itcast.tags.models.{AbstractModel, ModelType}
import cn.itcast.tags.tools.{MLModelTools, TagTools}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, VectorIndexer, VectorIndexerModel}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions._
/**
* 挖掘类型标签模型开发:用户购物性别标签模型
*/
class UsgTagModel extends AbstractModel("用户购物性别USG", ModelType.ML){
/*
378 用户购物性别
379 男 0
380 女 1
381 中性 -1
*/
override def doTag(businessDF: DataFrame, tagDF: DataFrame): DataFrame = {
val session: SparkSession = businessDF.sparkSession
import session.implicits._
/*
root
|-- cordersn: string (nullable = true)
|-- ogcolor: string (nullable = true)
|-- producttype: string (nullable = true)
*/
//businessDF.printSchema()
//businessDF.show(10, truncate = false)
//tagDF.printSchema()
/*
+---+----+----+-----+
|id |name|rule|level|
+---+----+----+-----+
|379|男 |0 |5 |
|380|女 |1 |5 |
|381|中性 |-1 |5 |
+---+----+----+-----+
*/
//tagDF.filter($"level".equalTo(5)).show(10, truncate = false)
// 1. 获取订单表数据tbl_tag_orders,与订单商品表数据关联获取会员ID
val ordersDF: DataFrame = spark.read
.format("hbase")
.option("zkHosts", "bigdata-cdh01.itcast.cn")
.option("zkPort", "2181")
.option("hbaseTable", "tbl_tag_orders")
.option("family", "detail")
.option("selectFields", "memberid,ordersn")
.load()
//ordersDF.printSchema()
//ordersDF.show(10, truncate = false)
// 2. 加载维度表数据:tbl_dim_colors(颜色)、tbl_dim_products(产品)
// 2.1 加载颜色维度表数据
val colorsDF: DataFrame = {
spark.read
.format("jdbc")
.option("driver", "com.mysql.jdbc.Driver")
.option("url",
"jdbc:mysql://bigdata-cdh01.itcast.cn:3306/?useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC")
.option("dbtable", "profile_tags.tbl_dim_colors")
.option("user", "root")
.option("password", "123456")
.load()
}
/*
root
|-- id: integer (nullable = false)
|-- color_name: string (nullable = true)
*/
//colorsDF.printSchema()
//colorsDF.show(30, truncate = false)
// 2.2. 构建颜色WHEN语句
val colorColumn: Column = {
// 声明变量
var colorCol: Column = null
colorsDF
.as[(Int, String)].rdd
.collectAsMap()
.foreach{case (colorId, colorName) =>
if(null == colorCol){
colorCol = when($"ogcolor".equalTo(colorName), colorId)
}else{
colorCol = colorCol.when($"ogcolor".equalTo(colorName), colorId)
}
}
colorCol = colorCol.otherwise(0).as("color")
// 返回
colorCol
}
// 2.3 加载商品维度表数据
val productsDF: DataFrame = {
spark.read
.format("jdbc")
.option("driver", "com.mysql.jdbc.Driver")
.option("url",
"jdbc:mysql://bigdata-cdh01.itcast.cn:3306/?useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC")
.option("dbtable", "profile_tags.tbl_dim_products")
.option("user", "root")
.option("password", "123456")
.load()
}
/*
root
|-- id: integer (nullable = false)
|-- product_name: string (nullable = true)
*/
//productsDF.printSchema()
///productsDF.show(30, truncate = false)
// 2.4. 构建商品类别WHEN语句
var productColumn: Column = {
// 声明变量
var productCol: Column = null
productsDF
.as[(Int, String)].rdd
.collectAsMap()
.foreach{case (productId, productName) =>
if(null == productCol){
productCol = when($"producttype".equalTo(productName), productId)
}else{
productCol = productCol.when($"producttype".equalTo(productName), productId)
}
}
productCol = productCol.otherwise(0).as("product")
// 返回
productCol
}
// 2.5. 根据运营规则标注的部分数据
val labelColumn: Column = {
when($"ogcolor".equalTo("樱花粉")
.or($"ogcolor".equalTo("白色"))
.or($"ogcolor".equalTo("香槟色"))
.or($"ogcolor".equalTo("香槟金"))
.or($"productType".equalTo("料理机"))
.or($"productType".equalTo("挂烫机"))
.or($"productType".equalTo("吸尘器/除螨仪")), 1) //女
.otherwise(0)//男
.alias("label")//决策树预测label
}
// 3. 关联订单数据,颜色维度和商品类别维度数据
val goodsDF: DataFrame = businessDF
// 关联订单数据:
.join(ordersDF, businessDF("cordersn") === ordersDF("ordersn"))
// 选择所需字段,使用when判断函数
.select(
$"memberid".as("userId"), //
colorColumn, // 颜色ColorColumn
productColumn, // 产品类别ProductColumn
// 依据规则标注商品性别
labelColumn
)
//goodsDF.printSchema()
//goodsDF.show(100, truncate = false)
// TODO: 直接使用标注数据,给用户打标签
//val predictionDF: DataFrame = goodsDF.select($"userId", $"label".as("prediction"))
// TODO: 实际上,需要使用标注数据(给每个用户购买每个商品打上性别字段),构建算法模型,并使用模型预测值性别字段值
//val featuresDF: DataFrame = featuresTransform(goodsDF)
//val dtcModel: DecisionTreeClassificationModel = trainModel(featuresDF)
//val predictionDF: DataFrame = dtcModel.transform(featuresDF)
// TODO: 使用数据集训练管道模型PipelineModel
//val pipelineModel: PipelineModel = trainPipelineModel(goodsDF)
//val predictionDF: DataFrame = pipelineModel.transform(goodsDF)
// TODO: 使用训练验证分割方式,设置超参数训练处最佳模型
//val pipelineModel: PipelineModel = trainBestModel(goodsDF)
//val predictionDF: DataFrame = pipelineModel.transform(goodsDF)
// TODO: 使用交叉验证方式,设置超参数训练获取最佳模型
/*
当模型存在时,直接加载模型;如果不存在,训练获取最佳模型,并保存
*/
val pipelineModel: PipelineModel = MLModelTools.loadModel(
goodsDF, "usg", this.getClass
).asInstanceOf[PipelineModel]
val predictionDF: DataFrame = pipelineModel.transform(goodsDF)
// 4. 按照用户ID分组,统计每个用户购物男性或女性商品个数及占比
val genderDF: DataFrame = predictionDF
.groupBy($"userId")
.agg(
count($"userId").as("total"), // 某个用户购物商品总数
// 判断label为0时,表示为男性商品,设置为1,使用sum函数累加
sum(
when($"prediction".equalTo(0), 1).otherwise(0)
).as("maleTotal"),
// 判断label为1时,表示为女性商品,设置为1,使用sum函数累加
sum(
when($"prediction".equalTo(1), 1).otherwise(0)
).as("femaleTotal")
)
/*
root
|-- userId: string (nullable = true)
|-- total: long (nullable = false)
|-- maleTotal: long (nullable = true)
|-- femaleTotal: long (nullable = true)
*/
//genderDF.printSchema()
//genderDF.show(10, truncate = false)
// 5. 计算标签
// 5.1 获取属性标签:tagRule和tagName
val rulesMap: Map[String, String] = TagTools.convertMap(tagDF)
val rulesMapBroadcast: Broadcast[Map[String, String]] = session.sparkContext.broadcast(rulesMap)
// 5.2 自定义UDF函数,计算占比,确定标签值
// 对每个用户,分别计算男性商品和女性商品占比,当占比大于等于0.6时,确定购物性别
val gender_tag_udf: UserDefinedFunction = udf(
(total: Long, maleTotal: Long, femaleTotal: Long) => {
// 计算占比
val maleRate: Double = maleTotal / total.toDouble
val femaleRate: Double = femaleTotal / total.toDouble
if(maleRate >= 0.6){ // usg = 男性
rulesMapBroadcast.value("0")
}else if(femaleRate >= 0.6){ // usg =女性
rulesMapBroadcast.value("1")
}else{ // usg = 中性
rulesMapBroadcast.value("-1")
}
}
)
// 5.3 获取画像标签数据
val modelDF: DataFrame = genderDF
.select(
$"userId", //
gender_tag_udf($"total", $"maleTotal", $"femaleTotal").as("usg")
)
//modelDF.printSchema()
//modelDF.show(100, truncate = false)
// 返回画像标签数据
modelDF
}
/**
* 针对数据集进行特征工程:特征提取、特征转换及特征选择
* @param dataframe 数据集
* @return 数据集,包含特征列features: Vector类型和标签列label
*/
def featuresTransform(dataframe: DataFrame): DataFrame = {
/*
// 1. 提取特征,封装值features向量中
// 2. 将标签label索引化
// 3. 将类别特征索引化
*/
// a. 特征向量化
val assembler: VectorAssembler = new VectorAssembler()
.setInputCols(Array("color", "product"))
.setOutputCol("raw_features")
val df1: DataFrame = assembler.transform(dataframe)
// b. 类别特征进行索引
val vectorIndexer: VectorIndexerModel = new VectorIndexer()
.setInputCol("raw_features")
.setOutputCol("features")
.setMaxCategories(30)
.fit(df1)
val df2: DataFrame = vectorIndexer.transform(df1)
// c. 返回特征数据
df2
}
/**
* 使用决策树分类算法训练模型,返回DecisionTreeClassificationModel模型
* @return
*/
def trainModel(dataframe: DataFrame): DecisionTreeClassificationModel = {
// a. 数据划分为训练数据集和测试数据集
val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2), seed = 123L)
// b. 构建决策树分类器
val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
.setFeaturesCol("features")
.setLabelCol("label")
.setPredictionCol("prediction")
.setImpurity("gini") // 基尼系数
.setMaxDepth(5) // 树的深度
.setMaxBins(32) // 树的叶子数目
// c. 训练模型
logWarning("正在训练模型...................................")
val dtcModel: DecisionTreeClassificationModel = dtc.fit(trainingDF)
// d. 模型评估
val predictionDF: DataFrame = dtcModel.transform(testingDF)
println(s"accuracy = ${modelEvaluate(predictionDF, "accuracy")}")
// e. 返回模型
dtcModel
}
/**
* 模型评估,返回计算分类指标值
* @param dataframe 预测结果的数据集
* @param metricName 分类评估指标名称,支持:f1、weightedPrecision、weightedRecall、accuracy
* @return
*/
def modelEvaluate(dataframe: DataFrame, metricName: String): Double = {
// a. 构建多分类分类器
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
// 指标名称,
.setMetricName(metricName)
// b. 计算评估指标
val metric: Double = evaluator.evaluate(dataframe)
// c. 返回指标
metric
}
/**
* 使用决策树分类算法训练模型,返回PipelineModel模型
* @return
*/
def trainPipelineModel(dataframe: DataFrame): PipelineModel = {
// 数据划分为训练数据集和测试数据集
val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2), seed = 123)
// a. 特征向量化
val assembler: VectorAssembler = new VectorAssembler()
.setInputCols(Array("color", "product"))
.setOutputCol("raw_features")
// b. 类别特征进行索引
val vectorIndexer: VectorIndexer = new VectorIndexer()
.setInputCol("raw_features")
.setOutputCol("features")
.setMaxCategories(30)
// c. 构建决策树分类器
val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
.setFeaturesCol("features")
.setLabelCol("label")
.setPredictionCol("prediction")
.setImpurity("gini") // 基尼系数
.setMaxDepth(5) // 树的深度
.setMaxBins(32) // 树的叶子数目
// TODO: 构建Pipeline管道对象,组合模型学习器(算法)和转换器(模型)
val pipeline: Pipeline = new Pipeline()
.setStages(Array(assembler, vectorIndexer, dtc))
// 训练模型,使用训练数据集
val pipelineModel: PipelineModel = pipeline.fit(trainingDF)
// f. 模型评估
val predictionDF: DataFrame = pipelineModel.transform(testingDF)
predictionDF.show(100, truncate = false)
println(s"accuracy = ${modelEvaluate(predictionDF, "accuracy")}")
// 返回模型
pipelineModel
}
/**
* 调整算法超参数,找出最优模型
* @param dataframe 数据集
* @return
*/
def trainBestModel(dataframe: DataFrame): PipelineModel = {
// a. 特征向量化
val assembler: VectorAssembler = new VectorAssembler()
.setInputCols(Array("color", "product"))
.setOutputCol("raw_features")
// b. 类别特征进行索引
val vectorIndexer: VectorIndexer = new VectorIndexer()
.setInputCol("raw_features")
.setOutputCol("features")
.setMaxCategories(30)
// c. 构建决策树分类器
val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
.setFeaturesCol("features")
.setLabelCol("label")
.setPredictionCol("prediction")
.setImpurity("gini") // 基尼系数
.setMaxDepth(5) // 树的深度
.setMaxBins(32) // 树的叶子数目
// 构建Pipeline管道对象,组合模型学习器(算法)和转换器(模型)
val pipeline: Pipeline = new Pipeline()
.setStages(Array(assembler, vectorIndexer, dtc))
// TODO: 创建网格参数对象实例,设置算法中超参数的值
val paramGrid: Array[ParamMap] = new ParamGridBuilder()
.addGrid(dtc.impurity, Array("gini", "entropy"))
.addGrid(dtc.maxDepth, Array(5, 10))
.addGrid(dtc.maxBins, Array(32, 64))
.build()
// f. 多分类评估器
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
// 指标名称,支持:f1、weightedPrecision、weightedRecall、accuracy
.setMetricName("accuracy")
// TODO: 创建训练验证分割实例对象,设置算法、评估器和数据集占比
val trainValidationSplit = new TrainValidationSplit()
.setEstimator(pipeline) // 算法,应用数据训练获取模型
.setEvaluator(evaluator) // 评估器,对训练模型进行评估
.setEstimatorParamMaps(paramGrid) // 算法超参数,自动组合算法超参数,训练模型并评估,获取最佳模型
// 80% of the data will be used for training and the remaining 20% for validation.
.setTrainRatio(0.8)
// 传递数据集,训练模型
val splitModel: TrainValidationSplitModel = trainValidationSplit.fit(dataframe)
// TODO: 获取最佳模型
val pipelineModel: PipelineModel = splitModel.bestModel.asInstanceOf[PipelineModel]
// 返回获取最佳模型
pipelineModel
}
}
object UsgTagModel {
def main(args: Array[String]): Unit = {
val tagModel = new UsgTagModel()
tagModel.executeModel(378L)
}
}
(叠甲:大部分资料来源于黑马程序员,这里只是做一些自己的认识、思路和理解,主要是为了分享经验,如果大家有不理解的部分可以私信我,也可以移步【黑马程序员_大数据实战之用户画像企业级项目】https://www.bilibili.com/video/BV1Mp4y1x7y7?p=201&vd_source=07930632bf702f026b5f12259522cb42,以上,大佬勿喷)