Spark MLlib Source Code Analysis: Random Forest (Part 2)

Related articles
Spark source code analysis: Random Forest (Part 1)
Spark source code analysis: Random Forest (Part 3)
Spark source code analysis: Random Forest (Part 4)
Spark source code analysis: Random Forest (Part 5)
Spark source code analysis: DecisionTree and GBDT

4. Feature Processing

This part lives mainly in the findSplitsBins function of DecisionTree.scala: every feature is wrapped into Splits, which are then binned into Bins. We first describe the Split and Bin structures.

4.1. Data Structures

4.1.1. Split
class Split(
    @Since("1.0.0") feature: Int,
    @Since("1.0.0") threshold: Double,
    @Since("1.0.0") featureType: FeatureType,
    @Since("1.0.0") categories: List[Double])
  • feature: the feature id
  • threshold: the split threshold, used for continuous features
  • featureType: continuous (Continuous) or categorical (Categorical)
  • categories: the list of feature values, used only for categorical features; it holds all the feature values belonging to this split
4.1.2. Bin
class Bin(
    lowSplit: Split, 
    highSplit: Split, 
    featureType: FeatureType, 
    category: Double)
  • lowSplit/highSplit: the lower/upper bound splits
  • featureType: continuous (Continuous) or categorical (Categorical)
  • category: the feature value, used for categorical features
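To make the two structures concrete, here is a minimal sketch using the same field names (the real classes live in the MLlib tree model package; the values below are hypothetical):

```scala
// Simplified stand-ins for MLlib's Split/Bin (same fields, hypothetical values).
sealed trait FeatureType
case object Continuous extends FeatureType
case object Categorical extends FeatureType

case class Split(feature: Int, threshold: Double,
                 featureType: FeatureType, categories: List[Double])
case class Bin(lowSplit: Split, highSplit: Split,
               featureType: FeatureType, category: Double)

// continuous feature 0: samples with value <= 3.5 go left
val contSplit = Split(feature = 0, threshold = 3.5, Continuous, Nil)

// categorical feature 1: samples whose value is in {0.0, 2.0} go left
// (threshold is unused for categorical splits)
val catSplit = Split(feature = 1, threshold = 0.0, Categorical, List(0.0, 2.0))

// a bin for the continuous feature, bounded by two adjacent splits
val bin = Bin(
  lowSplit = Split(0, 1.5, Continuous, Nil),
  highSplit = contSplit,
  featureType = Continuous,
  category = 0.0) // unused for continuous features
```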

4.2. Continuous Feature Processing

4.2.1. Sampling
val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
val sampledInput = if (continuousFeatures.nonEmpty) {
      // Calculate the number of samples for approximate quantile calculation.
      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
      val fraction = if (requiredSamples < metadata.numExamples) {
        requiredSamples.toDouble / metadata.numExamples
      } else {
        1.0
      }
      logDebug("fraction of data used for calculating quantiles = " + fraction)
      input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
    } else {
      input.sparkContext.emptyRDD[LabeledPoint]
    }

First the continuous features are selected; then the required sample count and the sampling fraction are computed, and a without-replacement sample of that fraction is drawn. If there are no continuous features, an empty RDD is returned.
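The sampling-fraction rule above can be checked with plain arithmetic (maxBins and numExamples below are hypothetical values):

```scala
// requiredSamples = max(maxBins^2, 10000); sample only if the data is larger.
val maxBins = 32
val numExamples = 1000000L

val requiredSamples = math.max(maxBins * maxBins, 10000) // max(1024, 10000) = 10000
val fraction =
  if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples
  else 1.0
// with 1,000,000 examples, only 1% of the data is sampled for quantile estimation
```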

4.2.2. Computing Splits
metadata.quantileStrategy match {
      case Sort =>
        findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
      case MinMax =>
        throw new UnsupportedOperationException("minmax not supported yet.")
      case ApproxHist =>
        throw new UnsupportedOperationException("approximate histogram not supported yet.")
    }

This matches on the quantile strategy; only Sort is implemented (as noted earlier), and the other two throw UnsupportedOperationException. The computation below happens in the findSplitsBinsBySorting function, whose inputs are the sampled data, the metadata, and the continuous feature set (feature ids starting from 0; see the construction of LabeledPoint).

val continuousSplits = {
    // reduce the parallelism for split computations when there are less
    // continuous features than input partitions. this prevents tasks from
    // being spun up that will definitely do no work.
    val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
    input.flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
         .groupByKey(numPartitions)
         .map { case (k, v) => findSplits(k, v) }
         .collectAsMap()
    }

The feature id is the key, and the value is the collection of that feature's values across all samples. Each group is passed to the findSplits function, which in turn calls findSplitsForContinuousFeature to obtain the splits of a continuous feature; its inputs are the samples, the metadata, and the feature index.
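The flatMap + groupByKey shuffle can be sketched with plain Scala collections (the sample data is made up), preserving the same key layout of feature id → all values of that feature:

```scala
// Collections-based sketch of the shuffle above: each sample emits
// (featureIdx, value) pairs, then values are grouped per feature id.
val continuousFeatures = Seq(0, 2)   // hypothetical continuous feature ids
val samples = Seq(                   // hypothetical feature vectors
  Array(1.0, 5.0, 9.0),
  Array(2.0, 6.0, 8.0),
  Array(3.0, 7.0, 7.0))

val grouped: Map[Int, Seq[Double]] =
  samples.flatMap(p => continuousFeatures.map(idx => (idx, p(idx))))
         .groupBy(_._1)
         .map { case (k, kv) => (k, kv.map(_._2)) }
// grouped(0) collects feature 0 across samples; grouped(2) feature 2
```

In the real code this grouping happens across partitions, so numPartitions is capped at the number of continuous features to avoid spinning up tasks with no work.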

def findSplitsForContinuousFeature(
    featureSamples: Array[Double],
    metadata: DecisionTreeMetadata,
    featureIndex: Int): Array[Double] = {
  require(metadata.isContinuous(featureIndex),
    "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

  val splits = {
    // a continuous feature has numBins - 1 splits
    val numSplits = metadata.numSplits(featureIndex)
    // get count for each distinct value
    ...
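The snippet above breaks off just as the value counting begins. As a hedged sketch (the real implementation in Spark 1.x follows this shape, but details differ across versions), the remainder counts the occurrences of each distinct value and then walks the sorted counts with a fixed stride, so that roughly numSamples / (numSplits + 1) samples fall between adjacent split thresholds:

```scala
// Sketch of sort-based split selection for one continuous feature.
def approxQuantileSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
  // count how many times each distinct value occurs, then sort by value
  val valueCounts = featureSamples
    .foldLeft(Map.empty[Double, Int]) { (m, x) => m + ((x, m.getOrElse(x, 0) + 1)) }
    .toSeq.sortBy(_._1)

  if (valueCounts.length <= numSplits + 1) {
    // few distinct values: simplification -- every distinct value except the last is a split
    valueCounts.map(_._1).init.toArray
  } else {
    // target: about `stride` samples between adjacent splits
    val stride = featureSamples.length.toDouble / (numSplits + 1)
    val splits = scala.collection.mutable.ArrayBuffer.empty[Double]
    var index = 1
    var currentCount = valueCounts(0)._2
    var targetCount = stride
    while (index < valueCounts.length) {
      val previousCount = currentCount
      currentCount += valueCounts(index)._2
      // place a split where the cumulative count is closest to the target
      if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
        splits += valueCounts(index - 1)._1
        targetCount += stride
      }
      index += 1
    }
    splits.toArray
  }
}

// 8 equally frequent values with 3 splits: thresholds land at 2.0, 4.0, 6.0
val thresholds = approxQuantileSplits((1 to 8).map(_.toDouble).toArray, numSplits = 3)
```

This is why the sampling step above only needs an approximate sample: the stride walk targets evenly weighted buckets, so small deviations in counts barely move the chosen thresholds.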