Pipeline管道自己的一点理解

概念

基于DataFrame API机器学习库Spark ML中提供Pipeline管道构建模型,方便实际项目中开发与部署。通俗地解释,就是一条流水线,它将多个处理步骤或组件串联起来,形成一个有序的工作流程。

特点

        流水线式设计

                Pipeline管道模型就像一个工厂中的流水线,原材料(数据)进入流水线的一端,经过一系列的加工处理(如数据清洗、转换、特征提取等),最终从流水线的另一端输出成品(处理后的数据或结果)。

        提升效率

                通过流水线式的设计,Pipeline管道模型可以显著提升数据处理的效率。在传统的数据处理过程中,可能需要手动编写大量的代码来执行每一步操作,并且需要手动管理数据在不同步骤之间的传递。而Pipeline管道模型将这些步骤封装在一起,自动管理数据的传递和处理,从而减少了重复劳动和出错的可能性。

        降低复杂度

                对于复杂的数据处理任务,Pipeline管道模型可以将任务拆分成多个简单的步骤,每个步骤只负责完成一部分工作。这种分而治之的策略可以降低任务的复杂度,使得每个步骤都更加容易理解和实现。

        灵活性和可扩展性

                Pipeline管道模型具有良好的灵活性和可扩展性。用户可以根据需要添加或删除步骤,也可以替换某个步骤的实现方式。这种灵活性使得Pipeline管道模型能够适应各种不同的数据处理需求。

应用领域

具体来说,Pipeline管道模型在以下场景中有着广泛的应用:

        机器学习中的数据预处理

                在机器学习任务中,数据预处理通常包括数据清洗、转换、特征提取等多个步骤。通过Pipeline管道模型,可以将这些步骤串联起来,形成一个完整的预处理流程,提高数据处理的效率和准确性。

        Redis客户端的批处理

                在Redis客户端中,使用Pipeline可以批量执行多个命令,减少网络交互的次数,从而提高性能。例如,当有多个命令需要被“及时的”提交,并且这些命令对响应结果没有互相依赖时,可以使用Pipeline进行批处理。

        网络编程中的事件处理

                在网络编程中,ChannelPipeline是一个处理或拦截channel的进站事件和出站事件的双向链表。通过增加或删除ChannelHandler,可以实现对不同业务逻辑的处理,提高网络编程的灵活性和可扩展性。

简单举个栗子:

        机器学习中的数据预处理Pipeline

                数据清洗:Pipeline的第一个步骤通常是数据清洗,它负责处理原始数据中的缺失值、异常值和重复数据等问题。例如,在一个预测房价的机器学习项目中,数据清洗可能包括删除含有缺失值的记录、填充缺失值(如使用均值、中位数等)、处理异常值(如设置阈值进行裁剪)等步骤。

                特征工程:接下来是特征工程步骤,它涉及特征选择、特征转换和特征降维等技术。在这个例子中,特征工程可能包括选择对房价预测有重要影响的特征(如房屋面积、位置、房龄等)、进行特征缩放(如归一化、标准化)、创建新的特征(如房屋面积与房龄的比值)等。

                模型训练和评估:经过特征工程后,数据被输入到选定的机器学习模型中进行训练和评估。训练过程中,可以使用交叉验证等技术来评估模型的性能,并通过调整超参数来优化模型。

                自动化和集成:通过Pipeline,这些步骤被整合在一起,形成一个自动化的数据预处理和模型训练流程。这样,当新的数据到来时,只需要将数据输入到Pipeline中,就可以自动完成数据预处理、特征工程、模型训练和评估等步骤,大大提高了工作效率。

        网络编程中的ChannelPipeline

                事件处理:在网络编程中,ChannelPipeline是一个处理或拦截channel的进站事件和出站事件的双向链表。当一个网络事件发生时(如数据接收、连接建立等),它会在ChannelPipeline中传递和处理。

                添加和删除Handler:用户可以在ChannelPipeline中添加或删除ChannelHandler来处理不同的事件和业务逻辑。例如,在一个TCP服务器中,可以添加一个Decoder Handler来解码接收到的字节数据为应用层消息,然后添加一个业务处理Handler来处理这些消息。

                灵活性:由于ChannelPipeline的灵活性,用户可以轻松地修改事件处理流程,以适应不同的业务需求和场景。例如,可以添加新的Handler来处理新的业务逻辑,或者删除不再需要的Handler来简化流程。

        Redis客户端的Pipeline批处理

                批量操作:在Redis客户端中,Pipeline允许用户将多个命令打包成一个请求发送给Redis服务器,从而减少网络交互的次数,提高性能。例如,假设有一个需求需要同时设置10000个键值对,使用Pipeline可以将这10000个命令打包成一个请求发送,而不是发送10000个单独的请求。

                性能提升:通过减少网络交互的次数,Pipeline可以显著提高Redis客户端的性能。在上面的例子中,使用Pipeline可以大大减少网络延迟和传输时间,从而加快键值对的设置速度。

                管道Pipeline概念:将多个Transformer转换器和Estimators模型学习器按照以来顺序组工作流WorkFlow形式,方便数据集的特征转换和模型训练以及预测。

概念图

转变成

关键词

        DataFrame:数据框,一种数据结构,来源于Spark SQL中,DataFrame = Dataset[Row],存储要训练和测试的数据集;

        Transformer:转换器,一种算法Algorithm,必须实现transform方法。比如:模型Model就是一个转换器,将输入的数据集DataFrame转换为预测结果的数据集DataFrame;

        Estimator:估计器或者模型学习器,将数据集DataFrame转换为一个Transformer,实现fit()方法,输入一个DataFrame并产生一个Model,即一个Transformer(转换器);

        Pipeline:管道,

        Parameter:参数,无论是转换器Transformer还是模型学习器Estimator都是一个算法,使用算法的时候必然会有参数。

一个经典的机器学习构建包括若干个过程:

以上四个步骤可以抽象为一个包括多个步骤的流水线式工作,从数据收集开始,至输出需要的最终结果。因此,对以上多个步骤进行抽象建模,简化为流水线式工作流程,则存在着可行性,对利用SparkMLlib进行机器学习的用户来说,流水线式机器学习比单个步骤独立建模更加高效、易用。

一个pipeline在结构上会包含以一个或多个Stage,每一个Stage都会完成一个认为,如数据集处理转化,模型训练,参数设置或数据预测等,这样的Stage在ML里按照处理问题类型的不同都有相应的定义和实现。两个主要的stage为Transformer和Estimator。

Transformer模型

        用来操作一个DataFrame数据并生成另一个DataFrame数据,比如svm模型,、一个特征提取工具,都可以抽象为一个Transformer。

Estimator算法

        用来做模型拟合,生成一个Transformer。

 

Pipeline官方案例

将【官方案例】决策树分类代码,改为使用Pipeline构建模型与预测,流程示意图如下:

 

代码实现 

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

object PipelineClassification {
	
	def main(args: Array[String]): Unit = {
		
		val spark = SparkSession.builder()
			.appName(this.getClass.getSimpleName.stripSuffix("$"))
			.master("local[4]")
			.getOrCreate()
		import org.apache.spark.sql.functions._
		import spark.implicits._
		
		// 1. 加载数据
		val dataframe: DataFrame = spark.read
			.format("libsvm")
			.load("datas/mllib/sample_libsvm_data.txt")
		//dataframe.printSchema()
		//dataframe.show(10, truncate = false)
		
		// 划分数据集:训练数据和测试数据
		val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2))
		
		// 2. 特征工程:特征提取、特征转换及特征选择
		// 2.1. 将标签值label,转换为索引,从0开始,到 K-1
		val labelIndexer: StringIndexerModel = new StringIndexer()
			.setInputCol("label")
			.setOutputCol("index_label")
			.fit(dataframe)
		val df1: DataFrame = labelIndexer.transform(dataframe)
		
		// 2.2. 对类别特征数据进行特殊处理, 当每列的值的个数小于等于设置K,那么此列数据被当做类别特征,自动进行索引转换
		val featureIndexer: VectorIndexerModel = new VectorIndexer()
			.setInputCol("features")
			.setOutputCol("index_features")
			// TODO: 表示哪些特征列为类别特征列,并且将特征列的特征值进行索引化转换操作
			.setMaxCategories(4) // 类别特征最大类别个数
			.fit(df1)
		val df2: DataFrame = featureIndexer.transform(df1)
		
		// 3. 使用决策树算法构建分类模型
		val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
			.setLabelCol("index_label")
			.setFeaturesCol("index_features")
			// 设置决策树算法相关超参数
			.setImpurity("gini") // 也可以是香农熵:entropy
			.setMaxDepth(5)
			.setMaxBins(32) // 此值必须大于等于类别特征类别个数
		
		// TODO: 4. 构建Pipeline管道,设置Stage,每个Stage要么是算法(模型学习器Estimator),要么是模型(转换器Transformer)
		val pipeline: Pipeline = new Pipeline()
			// 设置Stage,依赖顺序
    		.setStages(
			    Array(labelIndexer, featureIndexer, dtc)
		    )
		// 应用数据集,训练管道模型
		val pipelineModel: PipelineModel = pipeline.fit(trainingDF)
		// TODO: 获取决策树分类模型
		val dtcModel: DecisionTreeClassificationModel = pipelineModel
			.stages(2)
			.asInstanceOf[DecisionTreeClassificationModel]
		println(dtcModel.toDebugString)
		
		// 4. 模型评估
		val predictionDF: DataFrame = pipelineModel.transform(testingDF)
		predictionDF.printSchema()
		predictionDF
			.select($"index_label", $"probability", $"prediction")
			.show(20, truncate = false)
		val evaluator = new MulticlassClassificationEvaluator()
			.setLabelCol("index_label")
			.setPredictionCol("prediction")
			.setMetricName("accuracy")
		val accuracy: Double = evaluator.evaluate(predictionDF)
		println(s"Accuracy = $accuracy")
		
		spark.stop()
	}
	
}

扩展补充部分

将刚刚学到的Pipeline模型写入上一篇USG用户购物性别标签中,完善代码:

import cn.itcast.tags.models.{AbstractModel, ModelType}
import cn.itcast.tags.tools.{MLModelTools, TagTools}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, VectorIndexer, VectorIndexerModel}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions._

/**
 * 挖掘类型标签模型开发:用户购物性别标签模型
 */
class UsgTagModel extends AbstractModel("用户购物性别USG", ModelType.ML){
	/*
	378	用户购物性别
		379	男		0
		380	女		1
		381	中性		-1
	 */
	override def doTag(businessDF: DataFrame, tagDF: DataFrame): DataFrame = {
		val session: SparkSession = businessDF.sparkSession
		import session.implicits._
		
		/*
			root
			 |-- cordersn: string (nullable = true)
			 |-- ogcolor: string (nullable = true)
			 |-- producttype: string (nullable = true)
		 */
		//businessDF.printSchema()
		//businessDF.show(10, truncate = false)
		
		//tagDF.printSchema()
		/*
			+---+----+----+-----+
			|id |name|rule|level|
			+---+----+----+-----+
			|379|男   |0   |5    |
			|380|女   |1   |5    |
			|381|中性  |-1  |5    |
			+---+----+----+-----+
		 */
		//tagDF.filter($"level".equalTo(5)).show(10, truncate = false)
		
		
		// 1. 获取订单表数据tbl_tag_orders,与订单商品表数据关联获取会员ID
		val ordersDF: DataFrame = spark.read
			.format("hbase")
			.option("zkHosts", "bigdata-cdh01.itcast.cn")
			.option("zkPort", "2181")
			.option("hbaseTable", "tbl_tag_orders")
			.option("family", "detail")
			.option("selectFields", "memberid,ordersn")
			.load()
		//ordersDF.printSchema()
		//ordersDF.show(10, truncate = false)
		
		// 2. 加载维度表数据:tbl_dim_colors(颜色)、tbl_dim_products(产品)
		// 2.1 加载颜色维度表数据
		val colorsDF: DataFrame = {
			spark.read
				.format("jdbc")
				.option("driver", "com.mysql.jdbc.Driver")
				.option("url",
					"jdbc:mysql://bigdata-cdh01.itcast.cn:3306/?useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC")
				.option("dbtable", "profile_tags.tbl_dim_colors")
				.option("user", "root")
				.option("password", "123456")
				.load()
		}
		/*
			root
			 |-- id: integer (nullable = false)
			 |-- color_name: string (nullable = true)
		 */
		//colorsDF.printSchema()
		//colorsDF.show(30, truncate = false)
		
		// 2.2. 构建颜色WHEN语句
		val colorColumn: Column = {
			// 声明变量
			var colorCol: Column = null
			colorsDF
				.as[(Int, String)].rdd
				.collectAsMap()
				.foreach{case (colorId, colorName) =>
					if(null == colorCol){
						colorCol = when($"ogcolor".equalTo(colorName), colorId)
					}else{
						colorCol = colorCol.when($"ogcolor".equalTo(colorName), colorId)
					}
				}
			colorCol = colorCol.otherwise(0).as("color")
			//  返回
			colorCol
		}
		
		// 2.3 加载商品维度表数据
		val productsDF: DataFrame = {
			spark.read
				.format("jdbc")
				.option("driver", "com.mysql.jdbc.Driver")
				.option("url",
					"jdbc:mysql://bigdata-cdh01.itcast.cn:3306/?useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC")
				.option("dbtable", "profile_tags.tbl_dim_products")
				.option("user", "root")
				.option("password", "123456")
				.load()
		}
		/*
			root
			 |-- id: integer (nullable = false)
			 |-- product_name: string (nullable = true)
		 */
		//productsDF.printSchema()
		///productsDF.show(30, truncate = false)
		
		// 2.4. 构建商品类别WHEN语句
		var productColumn: Column = {
			// 声明变量
			var productCol: Column = null
			productsDF
				.as[(Int, String)].rdd
				.collectAsMap()
				.foreach{case (productId, productName) =>
					if(null == productCol){
						productCol = when($"producttype".equalTo(productName), productId)
					}else{
						productCol = productCol.when($"producttype".equalTo(productName), productId)
					}
				}
			productCol = productCol.otherwise(0).as("product")
			// 返回
			productCol
		}
		
		// 2.5. 根据运营规则标注的部分数据
		val labelColumn: Column = {
			when($"ogcolor".equalTo("樱花粉")
				.or($"ogcolor".equalTo("白色"))
				.or($"ogcolor".equalTo("香槟色"))
				.or($"ogcolor".equalTo("香槟金"))
				.or($"productType".equalTo("料理机"))
				.or($"productType".equalTo("挂烫机"))
				.or($"productType".equalTo("吸尘器/除螨仪")), 1) //女
				.otherwise(0)//男
				.alias("label")//决策树预测label
		}
		
		// 3. 关联订单数据,颜色维度和商品类别维度数据
		val goodsDF: DataFrame = businessDF
				// 关联订单数据:
	            .join(ordersDF, businessDF("cordersn") === ordersDF("ordersn"))
				// 选择所需字段,使用when判断函数
	            .select(
				    $"memberid".as("userId"), //
				    colorColumn, // 颜色ColorColumn
				    productColumn, // 产品类别ProductColumn
				    // 依据规则标注商品性别
				    labelColumn
			    )
		//goodsDF.printSchema()
		//goodsDF.show(100, truncate = false)
		
		// TODO: 直接使用标注数据,给用户打标签
		//val predictionDF: DataFrame = goodsDF.select($"userId", $"label".as("prediction"))
		
		// TODO: 实际上,需要使用标注数据(给每个用户购买每个商品打上性别字段),构建算法模型,并使用模型预测值性别字段值
		//val featuresDF: DataFrame = featuresTransform(goodsDF)
		//val dtcModel: DecisionTreeClassificationModel = trainModel(featuresDF)
		//val predictionDF: DataFrame = dtcModel.transform(featuresDF)
		
		// TODO: 使用数据集训练管道模型PipelineModel
		//val pipelineModel: PipelineModel = trainPipelineModel(goodsDF)
		//val predictionDF: DataFrame = pipelineModel.transform(goodsDF)
		
		// TODO: 使用训练验证分割方式,设置超参数训练处最佳模型
		//val pipelineModel: PipelineModel = trainBestModel(goodsDF)
		//val predictionDF: DataFrame = pipelineModel.transform(goodsDF)
		
		// TODO: 使用交叉验证方式,设置超参数训练获取最佳模型
		/*
			当模型存在时,直接加载模型;如果不存在,训练获取最佳模型,并保存
		 */
		val pipelineModel: PipelineModel = MLModelTools.loadModel(
			goodsDF, "usg", this.getClass
		).asInstanceOf[PipelineModel]
		val predictionDF: DataFrame = pipelineModel.transform(goodsDF)
		
		// 4. 按照用户ID分组,统计每个用户购物男性或女性商品个数及占比
		val genderDF: DataFrame = predictionDF
    		.groupBy($"userId")
    		.agg(
			    count($"userId").as("total"), // 某个用户购物商品总数
			    // 判断label为0时,表示为男性商品,设置为1,使用sum函数累加
			    sum(
				    when($"prediction".equalTo(0), 1).otherwise(0)
			    ).as("maleTotal"),
			    // 判断label为1时,表示为女性商品,设置为1,使用sum函数累加
			    sum(
				    when($"prediction".equalTo(1), 1).otherwise(0)
			    ).as("femaleTotal")
		    )
		/*
		root
			 |-- userId: string (nullable = true)
			 |-- total: long (nullable = false)
			 |-- maleTotal: long (nullable = true)
			 |-- femaleTotal: long (nullable = true)
		 */
		//genderDF.printSchema()
		//genderDF.show(10, truncate = false)
		
		// 5. 计算标签
		// 5.1 获取属性标签:tagRule和tagName
		val rulesMap: Map[String, String] = TagTools.convertMap(tagDF)
		val rulesMapBroadcast: Broadcast[Map[String, String]] = session.sparkContext.broadcast(rulesMap)
		
		// 5.2 自定义UDF函数,计算占比,确定标签值
		// 对每个用户,分别计算男性商品和女性商品占比,当占比大于等于0.6时,确定购物性别
		val gender_tag_udf: UserDefinedFunction = udf(
			(total: Long, maleTotal: Long, femaleTotal: Long) => {
				// 计算占比
				val maleRate: Double = maleTotal / total.toDouble
				val femaleRate: Double = femaleTotal / total.toDouble
				if(maleRate >= 0.6){ // usg = 男性
					rulesMapBroadcast.value("0")
				}else if(femaleRate >= 0.6){ // usg =女性
					rulesMapBroadcast.value("1")
				}else{ // usg = 中性
					rulesMapBroadcast.value("-1")
				}
			}
		)
		
		// 5.3 获取画像标签数据
		val modelDF: DataFrame = genderDF
			.select(
				$"userId", //
				gender_tag_udf($"total", $"maleTotal", $"femaleTotal").as("usg")
			)
		//modelDF.printSchema()
		//modelDF.show(100, truncate = false)
		
		// 返回画像标签数据
		modelDF
	}
	
	
	/**
	 * 针对数据集进行特征工程:特征提取、特征转换及特征选择
	 * @param dataframe 数据集
	 * @return 数据集,包含特征列features: Vector类型和标签列label
	 */
	def featuresTransform(dataframe: DataFrame): DataFrame = {
		/*
			// 1. 提取特征,封装值features向量中
			// 2. 将标签label索引化
			// 3. 将类别特征索引化
		 */
		// a. 特征向量化
		val assembler: VectorAssembler = new VectorAssembler()
			.setInputCols(Array("color", "product"))
			.setOutputCol("raw_features")
		val df1: DataFrame = assembler.transform(dataframe)
		
		// b. 类别特征进行索引
		val vectorIndexer: VectorIndexerModel = new VectorIndexer()
			.setInputCol("raw_features")
			.setOutputCol("features")
			.setMaxCategories(30)
			.fit(df1)
		val df2: DataFrame = vectorIndexer.transform(df1)
		
		// c. 返回特征数据
		df2
	}
	
	/**
	 * 使用决策树分类算法训练模型,返回DecisionTreeClassificationModel模型
	 * @return
	 */
	def trainModel(dataframe: DataFrame): DecisionTreeClassificationModel = {
		// a. 数据划分为训练数据集和测试数据集
		val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2), seed = 123L)
		// b. 构建决策树分类器
		val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
			.setFeaturesCol("features")
			.setLabelCol("label")
			.setPredictionCol("prediction")
			.setImpurity("gini") // 基尼系数
			.setMaxDepth(5) // 树的深度
			.setMaxBins(32) // 树的叶子数目
		// c. 训练模型
		logWarning("正在训练模型...................................")
		val dtcModel: DecisionTreeClassificationModel = dtc.fit(trainingDF)
		// d. 模型评估
		val predictionDF: DataFrame = dtcModel.transform(testingDF)
		println(s"accuracy = ${modelEvaluate(predictionDF, "accuracy")}")
		// e. 返回模型
		dtcModel
	}
	
	/**
	 * 模型评估,返回计算分类指标值
	 * @param dataframe 预测结果的数据集
	 * @param metricName 分类评估指标名称,支持:f1、weightedPrecision、weightedRecall、accuracy
	 * @return
	 */
	def modelEvaluate(dataframe: DataFrame, metricName: String): Double = {
		// a. 构建多分类分类器
		val evaluator = new MulticlassClassificationEvaluator()
			.setLabelCol("label")
			.setPredictionCol("prediction")
			// 指标名称,
			.setMetricName(metricName)
		// b. 计算评估指标
		val metric: Double = evaluator.evaluate(dataframe)
		// c. 返回指标
		metric
	}
	
	/**
	 * 使用决策树分类算法训练模型,返回PipelineModel模型
	 * @return
	 */
	def trainPipelineModel(dataframe: DataFrame): PipelineModel = {
		// 数据划分为训练数据集和测试数据集
		val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2), seed = 123)
		
		// a. 特征向量化
		val assembler: VectorAssembler = new VectorAssembler()
			.setInputCols(Array("color", "product"))
			.setOutputCol("raw_features")
		
		// b. 类别特征进行索引
		val vectorIndexer: VectorIndexer = new VectorIndexer()
			.setInputCol("raw_features")
			.setOutputCol("features")
			.setMaxCategories(30)
		
		// c. 构建决策树分类器
		val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
			.setFeaturesCol("features")
			.setLabelCol("label")
			.setPredictionCol("prediction")
			.setImpurity("gini") // 基尼系数
			.setMaxDepth(5) // 树的深度
			.setMaxBins(32) // 树的叶子数目
		
		// TODO: 构建Pipeline管道对象,组合模型学习器(算法)和转换器(模型)
		val pipeline: Pipeline = new Pipeline()
    		.setStages(Array(assembler, vectorIndexer, dtc))
		// 训练模型,使用训练数据集
		val pipelineModel: PipelineModel = pipeline.fit(trainingDF)
		
		// f. 模型评估
		val predictionDF: DataFrame = pipelineModel.transform(testingDF)
		predictionDF.show(100, truncate = false)
		println(s"accuracy = ${modelEvaluate(predictionDF, "accuracy")}")
		
		// 返回模型
		pipelineModel
	}
	
	/**
	 * 调整算法超参数,找出最优模型
	 * @param dataframe 数据集
	 * @return
	 */
	def trainBestModel(dataframe: DataFrame): PipelineModel = {
		// a. 特征向量化
		val assembler: VectorAssembler = new VectorAssembler()
			.setInputCols(Array("color", "product"))
			.setOutputCol("raw_features")
		
		// b. 类别特征进行索引
		val vectorIndexer: VectorIndexer = new VectorIndexer()
			.setInputCol("raw_features")
			.setOutputCol("features")
			.setMaxCategories(30)
		
		// c. 构建决策树分类器
		val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
			.setFeaturesCol("features")
			.setLabelCol("label")
			.setPredictionCol("prediction")
			.setImpurity("gini") // 基尼系数
			.setMaxDepth(5) // 树的深度
			.setMaxBins(32) // 树的叶子数目
		
		// 构建Pipeline管道对象,组合模型学习器(算法)和转换器(模型)
		val pipeline: Pipeline = new Pipeline()
			.setStages(Array(assembler, vectorIndexer, dtc))
		
		// TODO: 创建网格参数对象实例,设置算法中超参数的值
		val paramGrid: Array[ParamMap] = new ParamGridBuilder()
			.addGrid(dtc.impurity, Array("gini", "entropy"))
			.addGrid(dtc.maxDepth, Array(5, 10))
			.addGrid(dtc.maxBins, Array(32, 64))
			.build()
		
		// f. 多分类评估器
		val evaluator = new MulticlassClassificationEvaluator()
			.setLabelCol("label")
			.setPredictionCol("prediction")
			// 指标名称,支持:f1、weightedPrecision、weightedRecall、accuracy
			.setMetricName("accuracy")
		
		// TODO: 创建训练验证分割实例对象,设置算法、评估器和数据集占比
		val trainValidationSplit = new TrainValidationSplit()
			.setEstimator(pipeline) // 算法,应用数据训练获取模型
			.setEvaluator(evaluator) // 评估器,对训练模型进行评估
			.setEstimatorParamMaps(paramGrid) // 算法超参数,自动组合算法超参数,训练模型并评估,获取最佳模型
			// 80% of the data will be used for training and the remaining 20% for validation.
			.setTrainRatio(0.8)
		// 传递数据集,训练模型
		val splitModel: TrainValidationSplitModel = trainValidationSplit.fit(dataframe)
		
		// TODO: 获取最佳模型
		val pipelineModel: PipelineModel = splitModel.bestModel.asInstanceOf[PipelineModel]
		
		// 返回获取最佳模型
		pipelineModel
	}
	
}

object UsgTagModel {
	def main(args: Array[String]): Unit = {
		val tagModel = new UsgTagModel()
		tagModel.executeModel(378L)
	}
}

(叠甲:大部分资料来源于黑马程序员,这里只是做一些自己的认识、思路和理解,主要是为了分享经验,如果大家有不理解的部分可以私信我,也可以移步【黑马程序员_大数据实战之用户画像企业级项目】https://www.bilibili.com/video/BV1Mp4y1x7y7?p=201&vd_source=07930632bf702f026b5f12259522cb42,以上,大佬勿喷)

  • 23
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值