spark gbdt java_Spark.GBDT学习-GBTClassifier

本文介绍了Spark中用于分类的GBTClassifier,它基于Stochastic Gradient Boosting实现,不支持多分类任务。内容涵盖GBTClassifier的类结构、参数设置、训练方法以及模型预测过程。讲解了包括maxDepth、maxBins等参数,以及如何通过train方法训练得到GBTClassificationModel。
摘要由CSDN通过智能技术生成

用于分类的GBT(Gradient-Boosted Trees)算法,基于J.H. Friedman. "Stochastic Gradient Boosting"实现,目前不支持多分类任务。Gradient Boosting vs. TreeBoost:

本实现基于Stochastic Gradient Boosting(随机梯度提升),而不是TreeBoost

两种方法都是通过最小化损失函数,学习树的集成

TreeBoost方法相对于原始方法,基于损失函数对叶节点的输出进行了额外的修改

Spark考虑未来实现TreeBoost

GBTClassifier类

定义

一个唯一标识uid,继承了Predictor类,继承了GBTClassifierParams、DefaultParamsWritable、Logging特质。其中Predictor中的三个元素分别代表: 特征类型、学习器、学习到用于预测的模型。

class GBTClassifier(override val uid: String)

extends Predictor[Vector, GBTClassifier, GBTClassificationModel]

with GBTClassifierParams with DefaultParamsWritable with Logging

{

def this() = this(Identifiable.randomUID("gbtc"))

...

}

参数

为了兼容JAVA API,覆盖了继承自特质(with trait)的参数setter方法。

TreeClassifierParams参数

maxDepth

树的最大深度,0意味着只有一个叶节点,1意味着有一个内部节点+两个叶节点。

支持:>=0

默认:5

maxBins

用于离散连续特征的最大分桶数,用于每个节点特征分裂时分裂点的选择,分桶数越大意味着粒度越高。

支持:>=2并且>=任一类别特征的分类数

默认:32

minInstancesPerNode

分裂后每个子节点含有的最小样本数,如果分裂后左孩子或右孩子含有的样本数少于该值,则该分裂无效。

支持:>=1

默认:1

minInfoGain

树节点分裂时的最小信息增益。

支持:>=0.0

默认:0.0

maxMemoryInMB

每次会对一组节点进行切分,分组是按照树的层次逐步进行。每组需要切分的节点个数视内存大小而定,如果内存太小,每次只能切分一个节点。单位MB

默认:256MB

cacheNodeIds

如果为true,算法会为每个实例缓存树节点ID;如果为false,算法会将树传递给执行器用于匹配实例和树节点。缓存有利于加速训练深度较大的树,用户可以通过参数checkpointInterval设置缓存被检查的频率或者不检查。

默认:false

checkpointInterval

表示缓存的树节点ID的检查频率,当cacheNodeIds为true并且检查目录(checkpoint directory)通过sparkContext设置过才有效。

支持:>=1或者-1代表不检查,10意味着每10次迭代检查一次。

默认:10

impurity

用于计算信息增益的准则。不支持通过GBTClassifier.setImpurity方法设置该值。

支持:entropy、gini

默认:gini

TreeEnsembleParams参数

subsamplingRate

每一次迭代训练基学习器(决策树)时所使用的训练数据集的百分比。

支持:(0, 1]

默认:1.0

seed

随机数种子

默认:this.getClass.getName.hashCode.toLong

GBTParams参数

maxIter

最大迭代次数

支持:>=0

默认:20

stepSize

学习率(learning rate/step size)参数,用于缩小(shrinking)每个基学习器的贡献。

支持:(0, 1]

默认:0.1

GBTClassifierParams参数

lossType

GBT最小化的损失函数,不区分大小写。

支持:logistic

默认:logistic

方法

copy方法

GBTClassifier的拷贝函数。

train方法

GBTClassifier类的主要方法,用于训练得到学习好的用于预测的模型。

// @input: 训练数据, DataSet

// @output: 学习到的模型, GBTClassificationModel

override protected def train(dataset: Dataset[_]):

GBTClassificationModel = {

// 得到类别特征

val categoricalFeatures: Map[Int, Int] =

MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

// 转换训练数据并进行验证

// 将DataSet转换成RDD[LabeledPoint]

// 只支持二分类,要求label为0或1

val oldDataset: RDD[LabeledPoint] =

dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {

case Row(label: Double, features: Vector) =>

require(label == 0 || label == 1, s"GBTClassifier was given dataset with invalid label $label. Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.")

LabeledPoint(label, features)

}

// 获得特征个数及boosting策略

val numFeatures = oldDataset.first().features.size

val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)

// 用于记录日志

val instr = Instrumentation.create(this, oldDataset)

instr.logParams(params: _*)

instr.logNumFeatures(numFeatures)

instr.logNumClasses(2)

// 调用GradientBoostedTrees训练得到一组学习器及其权重

val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed))

// 将学到的模型封装成GBTClassificationModel并返回

val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)

instr.logSuccess(m)

m

}

GBTClassifier对象

object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {

// final变量,访问支持的损失函数类型

final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes

// 从目录中加载GBTClassifier

override def load(path: String): GBTClassifier = super.load(path)

}

GBTClassificationModel类

用于分类的GBT模型,仅支持二分类,支持连续特征和类别特征。

定义

继承了PredictionModel类以及多个特质,其中PredictionModel的两个元素分别代表特征类型、学习到用于预测的模型。

class GBTClassificationModel private[ml](

override val uid: String,

private val _trees: Array[DecisionTreeRegressionModel],

private val _treeWeights: Array[Double],

override val numFeatures: Int)

extends PredictionModel[Vector, GBTClassificationModel]

with GBTClassifierParams

with TreeEnsembleModel[DecisionTreeRegressionModel]

with MLWritable with Serializable

{

// 检查_trees.nonEmpty

// 检查_trees.length == _treeWeights.length

val numTrees: Int = _trees.length

...

}

方法

transformImpl方法

首先将GBTClassificationModel进行广播,然后通过udf进行预测数据,udf中调用predict方法实现。

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {

// 广播本类

val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)

val predictUDF = udf { (features: Any) =>

// udf通过本类的predict方法实现

bcastModel.value.predict(features.asInstanceOf[Vector])

}

// 使用udf将特征数据转换成预测数据

dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))

}

predict方法

关键的预测方法,先得到每个基学习器的预测值,然后进行融合得到最终的预测结果,最后得到类别结果。可以看到这里得到的预测值不是概率而是类别0/1,因为label被转换成了-1/+1,所以这里通过prediction>0.0得到预测lebel。

override protected def predict(features: Vector): Double = {

// 得到每棵树的预测结果

val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)

// 乘以权重之后求和得到融合结果

val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)

// 得到预测lebel

if (prediction > 0.0) 1.0 else 0.0

}

copy方法

GBTClassificationModel的拷贝方法。

toOld方法

将ml的模型转换成mllib中老的API,ml域的私有方法。

private[ml] def toOld: OldGBTModel = {

new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)

}

write方法

调用GBTClassificationModel对象的方法保存本模型。

override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)

GBTClassificationModel对象

fromOld方法

从老的API中转换出当前模型

GBTClassificationModelReader类

私有类,其中的load方法用于从目录中读取模型

GBTClassificationModelWriter类

私有类,其中的saveImpl方法用于将本模型保存

read方法

新建GBTClassificationModelReader类

load方法

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值