全面解析Apache Spark中的决策树

决策树是在顺序决策问题进行分类,预测和促进决策的有效方法。决策树由两部分组成:

决策(Desion)
结果(Outcome)

决策树包含三种类型的节点:

根节点(Root node):包含所有数据的树的顶层节点。
分割节点(Splitting node):将数据分配给子组(subgroup)的节点。
终端节点(Terminal node):最终决定(即结果)。

分割节点(Splitting node),就离散数学中的树的概念而言,就是指分支节点。

为了抵达终端结点或者说获得结果,该过程从根节点开始。根据在根节点上做出的决定,选择分支节点。基于在分支节点上做出的决定,选择下一个子分支节点。这个过程继续下去,直到我们到达终端节点,终端节点的值是我们的结果。

Apache Spark中的决策树

Apache Spark中没有决策树的实现可能听起来很奇怪。然而从技术上来说是有的。在Apache Spark中,您可以找到一个随机森林算法的实现,该算法实现可以由用户指定树的数量。因此,Apache Spark使用一棵树来调用随机森林。

在Apache Spark中,决策树是在特征空间上执行递归二进制分割的贪婪算法。树给每个最底部(即叶子结点)分区预测了相同的标签。为了最大化树的节点处的信息增益,通过在一组可能的分支中选择其中的最佳分割来贪婪地选择每个分支结点。

节点不纯度(impurity)是节点上标签一致性的度量。目前的实施提供了两种不纯的分类方法(Gini杂质和熵(Gini impurity and entropy))。

 

 

停止规则

在满足以下列条件之一的情况下,在节点处停止递归树构建(即只要满足一个就停止,译者注):

节点深度等于训练用的 maxDepth 参数。
没有候选的分割结点导致信息收益大于 minInfoGain 。
没有候选的分割结点去产生(至少拥有训练minInstancesPerNode实例)的子节点 。

有用的参数

algo:它可以是分类或回归。
numClasses:分类类的数量。
maxDepth:根据节点定义树的深度。
minInstancesPerNode:对于要进一步拆分的节点,其每个子节点必须至少接收到这样的训练实例数(即实例数必须等于这个参数)。
minInfoGain:对于一个节点进一步拆分,必须满足拆分后至少提高这么多信息量。
maxBins:离散连续特征时使用的bin数。

准备决策树的训练数据

您不能直接向决策树提供任何数据。它需要一种特殊的格式来提供。您可以使用 HashingTF 技术将训练数据转换为标记数据,以便决策树可以理解。这个过程也被称为数据的标准化。

(数据)供给和获得结果

一旦数据被标准化,您就可以提供相同的决策树算法进来行分类。但在此之前,您需要分割数据以用于训练和测试目的; 为了测试的准确性,你需要保留一部分数据进行测试。你可以像这样提供数据:

al splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.

val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins)

 

在这里,数据是我的标准化输入数据,为了训练和测试目的,我将其分成7:3的比例。我们正在使用最大深度的为5的"gini" 杂质("gini" impurity)。

一旦模型生成,您也可以尝试预测其他数据的分类。但在此之前,我们需要验证最近生成的模型的分类准确性。您可以通过计算"test error"来验证其准确性。

/ Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}

val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println("Test Error = " + testErr)

就是这样!你可以在这里查看一个正在运行的例子。

结语

感谢您的观看,如有不足之处,欢迎批评指正。

如果有对大数据感兴趣的小伙伴或者是从事大数据的老司机可以加群:

658558542    

欢迎大家交流分享,学习交流,共同进步。(里面还有大量的免费资料,帮助大家在成为大数据工程师,乃至架构师的路上披荆斩棘!)

最后祝福所有遇到瓶颈的大数据程序员们突破自己,祝福大家在往后的工作与面试中一切顺利。

转载于:https://my.oschina.net/u/4055005/blog/3000842

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值