利用spark的mllib构建GBDT模型

GBDT模型

GBDT模型的介绍,我主要是参考博客:http://blog.csdn.net/w28971023/article/details/8240756
在这里,我主要归纳以下几点要素:
1.GBDT中的树都是回归树;
2.回归树节点分割点衡量最好的标准是叶子个数的上限;
3.GBDT的核心在于,每个棵树学的是之前所有树结论和的残差,这个残差就是一个加预测值后能得到真实值的累加量;
4.GB为Gradient Boosting, Boosting的最大好处在于,每一步的残差计算其实变相地增大了分错instance的权重,而已经分对的instance则趋向于0;
5.GBDT采用一个Shrinkage策略,本质上,Shrinkage为每棵树设置了一个weight,累加时要乘以这个weight,但和Gradient并没有关系。

利用spark构建GBDT模型

训练GBDT模型

public void trainModel(){

        //初始化spark
        SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local");
        conf.set("spark.testing.memory","2147480000");
        SparkContext sc = new SparkContext(conf);

        //加载训练文件, 使用MLUtils包
        JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.trainsetFile).toJavaRDD();

        //训练模型, 默认情况下使用均值方差作为阈值标准
        int numIteration = 10;  //boosting提升迭代的次数
        int maxDepth = 3;       //回归树的最大深度
        BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");
        boostingStrategy.setNumIterations(numIteration);
        boostingStrategy.getTreeStrategy().setMaxDepth(maxDepth);
        //记录所有特征的连续结果
        Map<Integer, Integer> categoricalFeaturesInfoMap = new HashMap<Integer, Integer>();
        boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfoMap);
        //gdbt模型
        final GradientBoostedTreesModel model = GradientBoostedTrees.train(lpdata, boostingStrategy);
        model.save(sc, modelpath);
        sc.stop();
    }

预测数据

public void predict() {
        //初始化spark
        SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local");
        conf.set("spark.testing.memory","2147480000");
        SparkContext sc = new SparkContext(conf);

        //加载gbdt模型
        final GradientBoostedTreesModel model = GradientBoostedTreesModel.load(sc, this.modelpath);

        //加载测试文件
        JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictFile).toJavaRDD();
        testData.cache();


        //预测数据
        JavaRDD<Tuple2<Double, Double>>  predictionAndLabel = testData.map(new Prediction(model)) ;

        //计算所有数据的平均值方差
         Double testMSE = predictionAndLabel.map(new CountSquareError()).reduce(new ReduceSquareError()) / testData.count();
         System.out.println("testData's MSE is : " + testMSE);
         sc.stop();
    }

    static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> {
        GradientBoostedTreesModel model;
        public Prediction(GradientBoostedTreesModel model){
            this.model = model;
        }
        public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
            Double score = model.predict(p.features());
            return new Tuple2<Double , Double>(score, p.label());
        }
    }

    static class CountSquareError implements Function<Tuple2<Double, Double>, Double> {
        public Double call (Tuple2<Double, Double> pl) {
            double diff = pl._1() - pl._2();
            return diff * diff;
        }
    }

    static  class ReduceSquareError implements Function2<Double, Double, Double> {
        public Double call(Double a , Double b){
            return a + b ;
        }
    }

关于具体的代码放至我的github上:https://github.com/Quincy1994/MachineLearning

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值