Spark GBDT

随机森林(random forest)和GBDT都是属于集成学习(ensemble learning)的范畴。集成学习下有两个重要的策略Bagging和Boosting。
  Bagging算法是这样做的:每个分类器都随机从原样本中做有放回的采样,然后分别在这些采样后的样本上训练分类器,然后再把这些分类器组合起来。简单的多数投票一般就可以。其代表算法是随机森林。Boosting的意思是这样,他通过迭代地训练一系列的分类器,每个分类器采用的样本分布都和上一轮的学习结果有关。其代表算法是AdaBoost, GBDT。

    val conf=new SparkConf().setAppName("GBDTExample")
    val sc=new SparkContext(conf)
    val sqlcontext=new SQLContext(sc)
    import sqlcontext.implicits._
    val data = MLUtils.loadLibSVMFile(sc,"/tmp/sample_libsvm_data.txt").toDF("label","features")


    val splits=data.randomSplit(Array(0.7,0.3))
    val (trainData,testData)=(splits(0),splits(1))

    val labelIndexer = new StringIndexer()
      .setInputCol("lable")
      .setOutputCol("indexLable")
      .fit(data)

    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(data)

    val gbdt=new GBTClassifier()
        .setLabelCol("indexLable")
        .setFeaturesCol("indexedFeatures")
        .setMaxIter(10)

    val lableConvert=new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictionLable")
      .setLabels(labelIndexer.labels)

    val pipeline=new Pipeline()
      .setStages(Array(labelIndexer,featureIndexer,gbdt,lableConvert))

    val model=pipeline.fit(trainData)

    val predications=model.transform(testData)

    predications.select("predictionLable", "label", "features").show(5)

    val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]

    println("Learned classification GBT model:\n" + gbtModel.toDebugString)


  }

运行结果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值