spark mllib源码分析之随机森林(Random Forest)(四)

本文深入探讨Spark MLlib中随机森林的节点分裂过程,包括数据统计、节点分裂的最佳切分点寻找,以及连续和无序、有序特征的处理方法。文章详尽分析了如何计算节点的最优分裂点及其增益,以及如何进行节点的分裂和循环训练。
摘要由CSDN通过智能技术生成

spark源码分析之随机森林(Random Forest)(一)
spark源码分析之随机森林(Random Forest)(二)
spark源码分析之随机森林(Random Forest)(三)
spark源码分析之随机森林(Random Forest)(五)

6.4. node分裂

逻辑主要在DecisionTree.findBestSplits函数中,是RF训练最核心的部分

DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
6.4.1. 数据统计

数据统计分成两部分,先在各个partition上分别统计,再累积各partition成全局统计。

6.4.1.1. 取出node的特征子集
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)

取出各node的特征子集,如果不需要抽样则为None;否则返回Map[Int, Array[Int]],其实就是将之前treeToNodeToIndexInfo中的NodeIndexInfo转换为map结构,将其作为广播变量nodeToFeaturesBc。

6.4.1.2. 分区统计

一系列函数的调用链,我们逐层分析

val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }

        // iterator all instances in current partition and update aggregate stats
        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    } else {
      input.mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }

        // iterator all instances in current partition and update aggregate stats
        points.foreach(binSeqOp(nodeStatsAggregators, _))

        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    }

首先对每个partition构造一个DTStatsAggregator数组,长度是node的个数,注意这里实际使用的是数组,node怎样与自己的aggregator的对应?前面我们提到NodeIndexInfo的第一个成员是groupIndex,其值就是node的次序,和这里aggregator数组index其实是对应的,也就是说可以从NodeIndexInfo中取得groupIndex,然后作为数组index取得对应node的agg。DTStatsAggregator的入参是metadata和每个node的特征子集。然后将每个点统计到DTStatsAggregator中,其中调用了内部函数binSeqOp,

 /**
     * Performs a sequential aggregation over a partition.
     *
     * Each data point contributes to one node. For each feature,
     * the aggregate sufficient statistics are updated for the relevant bins.
     *
     * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
     *             each (node, feature, bin).
     * @param baggedPoint   Data point being aggregated.
     * @return  agg
     */
    def binSeqOp(
        agg: Array[DTStatsAggregator],
        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
    //对每个node
      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
        val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
          bins, metadata.unorderedFeatures)
        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
      }

      agg
    }

首先调用函数predictNodeIndex计算nodeIndex,如果是首轮或者叶子节点,直接返回node.id;如果不是首轮,因为传入的是每棵树的root node,就从root node开始,逐渐往下判断该point应该是属于哪个node的,因为我们已经对node进行了分裂,这里其实实现了样本的划分。举个栗子,当前node如果是root的左孩子节点,而point预测节点应该属于右孩子,则调用nodeBinSepOp时就直接返回了,不会将这个point统计进去,用不大的时间换取样本集划分的空间,还是比较巧妙的。

/**
   * Get the node index corresponding to this data point.
   * Thi
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是使用Java编写Spark MLlib中的随机森林算法的示例代码: ```java import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.RandomForestClassificationModel; import org.apache.spark.ml.classification.RandomForestClassifier; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.ml.feature.IndexToString; import org.apache.spark.ml.feature.StringIndexer; import org.apache.spark.ml.feature.VectorAssembler; import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class RandomForestExample { public static void main(String[] args) { // 创建SparkConf对象 SparkConf sparkConf = new SparkConf().setAppName("RandomForestExample").setMaster("local"); // 创建JavaSparkContext对象 JavaSparkContext jsc = new JavaSparkContext(sparkConf); // 创建SQLContext对象 SQLContext sqlContext = new SQLContext(jsc); // 加载数据集 Dataset<Row> data = sqlContext.read().format("csv").option("header", "true").load("path/to/dataset.csv"); // 数据预处理 StringIndexer labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data); VectorAssembler assembler = new VectorAssembler().setInputCols(new String[]{"feature1", "feature2", "feature3"}).setOutputCol("features"); Dataset<Row> assembledData = assembler.transform(data); Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3}); Dataset<Row> trainingData = splits[0]; Dataset<Row> testData = splits[1]; // 构建随机森林分类模型 RandomForestClassifier rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("features").setNumTrees(10); VectorIndexerModel featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(trainingData); IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels()); // 训练模型 Dataset<Row> indexedTrainingData = featureIndexer.transform(trainingData); RandomForestClassificationModel model = rf.fit(indexedTrainingData); // 测试模型 Dataset<Row> indexedTestData = featureIndexer.transform(testData); Dataset<Row> predictions = model.transform(indexedTestData); predictions.select("predictedLabel", "label", "features").show(10); // 评估模型 MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy"); double accuracy = evaluator.evaluate(predictions); System.out.println("Test Error = " + (1.0 - accuracy)); // 关闭JavaSparkContext对象 jsc.stop(); } } ``` 其中,我们首先加载数据集并进行预处理,然后构建随机森林分类模型,使用训练数据训练模型,使用测试数据测试模型,并计算模型的准确率,最后关闭JavaSparkContext对象。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值