本章重点关注分裂节点的划分。
具体执行的代码为 DecisionTree.findSplitsBins():
/**
 * Computes the candidate splits and bins for every feature.
 *
 * For continuous features, splits are estimated from a subsample via
 * findSplitsForContinuousFeature; for unordered categorical features, splits
 * enumerate category subsets; for ordered categorical features, splits are
 * built later during training.
 *
 * @param input    training data
 * @param metadata per-feature metadata (continuity, arity, split/bin counts)
 * @return (splits, bins): one array of Splits and one array of Bins per feature
 */
protected[tree] def findSplitsBins(
input: RDD[LabeledPoint],
metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
logDebug("isMulticlass = " + metadata.isMulticlass)
val numFeatures = metadata.numFeatures // number of features
// Sample the input only if there are continuous features.
val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous) // true if any feature is continuous
val sampledInput = if (hasContinuousFeatures) { // continuous case: subsample for quantile estimation
// Calculate the number of samples for approximate quantile calculation.
val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) // use a subsample to locate splits; at least 10000 samples
val fraction = if (requiredSamples < metadata.numExamples) {
requiredSamples.toDouble / metadata.numExamples // sampling fraction
} else {
1.0
}
logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
} else { // no continuous features: no sample needed
new Array[LabeledPoint](0)
}
metadata.quantileStrategy match { // split-point strategy; only Sort is currently implemented
case Sort =>
val splits = new Array[Array[Split]](numFeatures) // one array of split positions per feature
val bins = new Array[Array[Bin]](numFeatures) // bins paired with the splits; hold the interval information between thresholds
// Find all splits.
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
if (metadata.isContinuous(featureIndex)) { // continuous feature
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)) // sampled values of feature `featureIndex`
val featureSplits = findSplitsForContinuousFeature(featureSamples, // split positions for this feature (possibly several)
metadata, featureIndex)
val numSplits = featureSplits.length // number of split positions for this feature
val numBins = numSplits + 1 // number of bins (cutting twice yields three pieces)
logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
splits(featureIndex) = new Array[Split](numSplits) // split array for this feature
bins(featureIndex) = new Array[Bin](numBins) // corresponding bin array
var splitIndex = 0
while (splitIndex < numSplits) {
val threshold = featureSplits(splitIndex) // value at each split; splits are sorted, so each acts as a threshold
splits(featureIndex)(splitIndex) =
new Split(featureIndex, threshold, Continuous, List()) // one Split per threshold for this feature
splitIndex += 1
}
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue) // leftmost bin: lower bound is a dummy low split
splitIndex = 1
while (splitIndex < numSplits) {
bins(featureIndex)(splitIndex) =
new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex), // interior bin covers the interval between two consecutive thresholds
Continuous, Double.MinValue)
splitIndex += 1
}
bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) // rightmost bin: upper bound is a dummy high split
} else { // categorical feature
val numSplits = metadata.numSplits(featureIndex) // number of splits
val numBins = metadata.numBins(featureIndex) // number of bins
// Categorical feature
val featureArity = metadata.featureArity(featureIndex) // number of distinct categories for this feature
if (metadata.isUnordered(featureIndex)) { // unordered categorical feature
// TODO: The second half of the bins are unused. Actually, we could just use
// splits and not build bins for unordered features. That should be part of
// a later PR since it will require changing other code (using splits instead
// of bins in a few places).
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)
var splitIndex = 0
while (splitIndex < numSplits) {
val categories: List[Double] =
extractMultiClassCategories(splitIndex + 1, featureArity) // subset of category values (one or more) for this split
splits(featureIndex)(splitIndex) =
new Split(featureIndex, Double.MinValue, Categorical, categories)
bins(featureIndex)(splitIndex) = {
if (splitIndex == 0) {
new Bin(
new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0),
Categorical,
Double.MinValue)
} else {
new Bin(
splits(featureIndex)(splitIndex - 1),
splits(featureIndex)(splitIndex),
Categorical,
Double.MinValue)
}
}
splitIndex += 1
}
} else { // ordered categorical feature: nothing precomputed; built during training
// Ordered features
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
splits(featureIndex) = new Array[Split](0)
bins(featureIndex) = new Array[Bin](0)
}
}
featureIndex += 1
}
(splits, bins)
case MinMax =>
throw new UnsupportedOperationException("minmax not supported yet.")
case ApproxHist =>
throw new UnsupportedOperationException("approximate histogram not supported yet.")
}
}
接下来,看看怎么确定分裂点的位置(特征属性值为连续的情况):
/**
 * Finds split thresholds for one continuous feature from its sampled values.
 *
 * Distinct values are counted and sorted; if there are no more distinct values
 * than requested splits, every distinct value becomes a threshold. Otherwise
 * thresholds are chosen greedily so that roughly `stride` samples fall between
 * consecutive thresholds. The actual split count is written back into
 * `metadata` before returning.
 *
 * @param featureSamples sampled values of this feature
 * @param metadata       tree metadata; supplies and receives the split count
 * @param featureIndex   index of the (continuous) feature
 * @return sorted array of split thresholds (non-empty)
 */
private[tree] def findSplitsForContinuousFeature(
featureSamples: Array[Double],
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

val maxSplits = metadata.numSplits(featureIndex) // requested number of splits

// Count occurrences of each distinct value, then order ascending by value.
var counts = Map.empty[Double, Int]
for (sample <- featureSamples) {
counts = counts.updated(sample, counts.getOrElse(sample, 0) + 1)
}
val sortedCounts: Array[(Double, Int)] = counts.toSeq.sortBy(_._1).toArray

val splits: Array[Double] =
if (sortedCounts.length <= maxSplits) {
// Not more distinct values than requested splits: use them all as thresholds.
sortedCounts.map(_._1)
} else {
// Target number of samples between consecutive thresholds.
val stride: Double = featureSamples.length.toDouble / (maxSplits + 1)
logDebug("stride = " + stride)
val chosen = new ArrayBuffer[Double]
// Running total of sample counts seen so far; starts at the first value's count.
var cumCount = sortedCounts(0)._2
// Next cumulative-count target; a threshold is taken at the value whose
// running total is closest to it, then the target advances by one stride.
var target = stride
var i = 1
while (i < sortedCounts.length) {
val beforeCount = cumCount
cumCount += sortedCounts(i)._2
// If including the current value moves the running total further from the
// target than it was, the previous value is the best threshold here.
if (math.abs(beforeCount - target) < math.abs(cumCount - target)) {
chosen += sortedCounts(i - 1)._1
target += stride
}
i += 1
}
chosen.toArray
}

assert(splits.length > 0)
// Record the actual number of splits found back into the metadata.
metadata.setNumSplits(featureIndex, splits.length)
splits
}
************* The End *************