继续看 binsToBestSplit,这个函数的作用是为一个节点找到最优划分。
// Calculate InformationGain and ImpurityStats if current node is top node 如果当前节点是根节点计算信息增益和不纯性统计
val level = LearningNode.indexToLevel(node.id) //indexToLevel是根据node.id计算level,这个level是从0开始
var gainAndImpurityStats: ImpurityStats = if (level ==0) {null} else { node.stats}//如果是根节点,增益和不纯性统计为null,因为构造根节点的时候只放了nodeIndex;如果不是根节点,就取node.stats,这说明后面会计算stats
// For each (feature, split), calculate the gain, and select the best (feature, split).//对每个属性的每个划分,计算增益,并选择最优的属性划分
val (bestSplit, bestSplitStats) = //返回最优划分,最优划分的统计
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>//循环节点的每个属性
val featureIndex = if (featuresForNode.nonEmpty) { //找到每个属性的索引
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
val numSplits = binAggregates.metadata.numSplits(featureIndex)//求出这个属性有多少划分
if (binAggregates.metadata.isContinuous(featureIndex)) {//下面分三种情况计算,连续属性,unordered属性,ordered属性
// Cumulative sum (scanLeft) of bin statistics. //聚合bin统计,这样每个bin的聚合信息就是包括这个bin和之前bin的聚合信息
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)//计算属性索引的偏移
var splitIndex = 0
while (splitIndex < numSplits) {//遍历每个划分
binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)//把前面的bin聚合到后面的bin上,改变的是后面的bin,这样就实现了上面英文注释所说的累计求和功能
splitIndex += 1
}
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =//寻找最优划分,返回最优属性划分索引和最优属性增益统计
Range(0, numSplits).map { case splitIdx =>//对每个划分
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)//获取属性划分的不纯性值,这时统计值已经被更新成累计值了
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)//获取属性总的不纯性值
rightChildStats.subtract(leftChildStats)//不纯性相减,结果更新到rightChildStats
下面我们先看一下 calculateImpurityStats 这个方法,它根据左右子节点的统计信息计算不纯性和信息增益,并判断该划分是否有效:
private def calculateImpurityStats(
stats: ImpurityStats,//不纯性统计,如果是根节点是null,这是我们上面提到的
leftImpurityCalculator: ImpurityCalculator,//左不纯性统计
rightImpurityCalculator: ImpurityCalculator,//右不纯性统计,根据前面计算结果已经是不纯性的差了
metadata: DecisionTreeMetadata): ImpurityStats = {//元数据
val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {//父节点不纯性统计
leftImpurityCalculator.copy.add(rightImpurityCalculator)//如果是根节点把左右相加,其实还是最初numSplit对应的不纯性统计
} else {
stats.impurityCalculator//否则取传过来的不纯性统计
}
val impurity: Double = if (stats == null) {
parentImpurityCalculator.calculate()//根节点计算不纯性值,计算公式就是 val squaredLoss = sumSquares - (sum * sum) / count,返回squaredLoss / count,用的都是统计信息
} else {
stats.impurity//否则直接取传过来的不纯性值
}
val leftCount = leftImpurityCalculator.count//左不纯性包含的实例数量
val rightCount = rightImpurityCalculator.count//右不纯性包含实例数量
val totalCount = leftCount + rightCount//总数
// If left child or right child doesn't satisfy minimum instances per node,//如果左右子树不满足最小实例条件,划分无效,返回无效信息增益统计
// then this split is invalid, return invalid information gain stats.
if ((leftCount < metadata.minInstancesPerNode) ||
(rightCount < metadata.minInstancesPerNode)) {
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)//返回一个无效的不纯性统计
}
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0//获取左子树不纯性
val rightImpurity = rightImpurityCalculator.calculate()//获取右子树不纯性
val leftWeight = leftCount / totalCount.toDouble//左子树权重
val rightWeight = rightCount / totalCount.toDouble//右子树权重
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity//计算信息增益
// if information gain doesn't satisfy minimum information gain,//如果信息增益不满足最小增益量,返回无效信息增益统计
// then this split is invalid, return invalid information gain stats.
if (gain < metadata.minInfoGain) {