上篇已经分析到Spark MLlib库的决策树最终实现使用了random forest的run方法,这篇将对run方法进行详细的剖析和解释。
上篇提到input先被转换成Metadata处理,因此首先看一下buildMetadata方法
可以看出DecisionTreeMetadata确定了划分区间(bin)数在不同情况下的范围,将数据的属性分为了有序和无序两种情况。将二元分类和回归问题放在了一起考虑。
另外,在分割数量上,对于连续数值,先进行抽样,然后分割(split)数目就是bin数减一;对于离散数据,分成有序和无序属性讨论,有序情况使用每个属性的类别数量作为划分(bin)数量,无序情况下则使用属性类别的子集作为划分依据,并限制了无序处理的属性类别数上限,避免属性类别太多时产生指数级数量的划分。
private[tree] object DecisionTreeMetadata extends Logging {

  /**
   * Construct a [[DecisionTreeMetadata]] instance for the given dataset and strategy.
   *
   * Decides how many bins each feature gets and which categorical features are treated
   * as unordered (split on subsets of categories) versus ordered (one bin per category).
   * Binary classification and regression are handled together; only multiclass
   * classification uses unordered categorical features.
   *
   * @param input                 training data as an RDD of labeled points
   * @param strategy              tree-learning parameters (algo, maxBins, categorical info, ...)
   * @param numTrees              number of trees being trained (1 for a single decision tree)
   * @param featureSubsetStrategy per-node feature-sampling strategy:
   *                              "auto", "all", "sqrt", "log2", or "onethird"
   */
  def buildMetadata(
      input: RDD[LabeledPoint],
      strategy: Strategy,
      numTrees: Int,
      featureSubsetStrategy: String): DecisionTreeMetadata = {

    // Number of features. Fail fast with a clear message on an empty RDD;
    // the former `input.take(1)(0)` threw an opaque ArrayIndexOutOfBoundsException.
    val numFeatures = input.take(1).headOption.map(_.features.size).getOrElse {
      throw new IllegalArgumentException(
        "DecisionTree requires size of input RDD > 0, but was given an empty one.")
    }
    val numExamples = input.count() // total number of training instances
    val numClasses = strategy.algo match {
      case Classification => strategy.numClasses // number of label classes
      case Regression => 0 // unused for regression
    }

    // Cap the number of bins: there can never be more bins than training instances.
    val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
    if (maxPossibleBins < strategy.maxBins) {
      logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
        s" (= number of training instances)")
    }

    // categoricalFeaturesInfo maps feature index -> number of categories (arity).
    // Every categorical feature must fit within the bin budget.
    if (strategy.categoricalFeaturesInfo.nonEmpty) {
      val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
      require(maxCategoriesPerFeature <= maxPossibleBins,
        s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
        s"in categorical features (= $maxCategoriesPerFeature)")
    }

    // Categorical features treated as unordered (only in multiclass classification,
    // and only when their arity is small enough for subset-based splits).
    val unorderedFeatures = new mutable.HashSet[Int]()
    // Continuous features (and any feature not overridden below) use maxPossibleBins bins.
    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
    if (numClasses > 2) {
      // Multiclass classification.
      val maxCategoriesForUnorderedFeature =
        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
        // Decide if some categorical features should be treated as unordered features,
        // which require 2 * ((1 << numCategories - 1) - 1) bins.
        // We do this check with log values to prevent overflows in case numCategories is large.
        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
        if (numCategories <= maxCategoriesForUnorderedFeature) {
          unorderedFeatures.add(featureIndex)
          numBins(featureIndex) = numUnorderedBins(numCategories)
        } else {
          numBins(featureIndex) = numCategories
        }
      }
    } else {
      // Binary classification or regression: all categorical features are ordered,
      // with one bin per category.
      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
        numBins(featureIndex) = numCategories
      }
    }

    // Resolve the per-node feature-subset strategy (used by random forests;
    // a single decision tree always uses all features).
    val _featureSubsetStrategy = featureSubsetStrategy match {
      case "auto" =>
        if (numTrees == 1) {
          "all"
        } else if (strategy.algo == Classification) {
          "sqrt"
        } else {
          "onethird"
        }
      case _ => featureSubsetStrategy
    }
    val numFeaturesPerNode: Int = _featureSubsetStrategy match {
      case "all" => numFeatures
      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
      case "onethird" => (numFeatures / 3.0).ceil.toInt
      case invalid =>
        // Previously an unrecognized strategy crashed with a bare MatchError.
        throw new IllegalArgumentException(
          s"DecisionTreeMetadata given unrecognized featureSubsetStrategy: $invalid. " +
            "Supported values: auto, all, sqrt, log2, onethird.")
    }

    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
      strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
  }

  /**
   * Convenience overload for a single decision tree: one tree, all features at each node.
   */
  def buildMetadata(
      input: RDD[LabeledPoint],
      strategy: Strategy): DecisionTreeMetadata = {
    buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
  }

  /**
   * Number of bins for an unordered categorical feature with the given arity
   * (number of categories): 2 * (2^(arity - 1) - 1), i.e. two bins (left/right)
   * for each of the 2^(arity - 1) - 1 distinct subset-based splits.
   *
   * Note: `1 << arity - 1` parses as `1 << (arity - 1)` because `-` binds
   * tighter than `<<` in Scala.
   */
  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
}