Spark-MLlib的快速使用之八(决策树-回归)

 

通俗来说,决策树分类的思想类似于找对象。现想象一个女孩的母亲要给这个女孩介绍男朋友,于是有了下面的对话:

女儿:多大年纪了? 母亲:26。 女儿:长的帅不帅? 母亲:挺帅的。 女儿:收入高不? 母亲:不算很高,中等情况。 女儿:是公务员不? 母亲:是,在税务局上班呢。 女儿:那好,我去见见。

(1)训练数据

0 1:32 2:1 3:1 4:0

0 1:25 2:1 3:2 4:0

1 1:29 2:1 3:2 4:1

1 1:24 2:1 3:1 4:0

0 1:31 2:1 3:1 4:0

1 1:35 2:1 3:2 4:1

0 1:30 2:0 3:1 4:0

0 1:31 2:1 3:1 4:0

1 1:30 2:1 3:2 4:1

1 1:21 2:1 3:1 4:0

0 1:21 2:1 3:2 4:0

1 1:21 2:1 3:2 4:1

0 1:29 2:0 3:2 4:1

0 1:29 2:1 3:0 4:1

0 1:29 2:0 3:2 4:1

1 1:30 2:1 3:1 4:0

解释:行动(1 见面 0 不见面)(1:32 年龄:32岁)(2:1 长相:帅或中等)(3:1 收入:高)(4:0 公务员:不是)

(2)测试数据

0 1:32 2:1 3:2 4:0

1 1:27 2:1 3:1 4:1

1 1:29 2:1 3:1 4:0

1 1:25 2:1 3:2 4:1

0 1:23 2:0 3:2 4:1

(3)样例代码

public static void main(String[] args) {

SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample").setMaster("local");

JavaSparkContext jsc = new JavaSparkContext(sparkConf);

String datapath = "jueceshu.txt";

String ceshipath = "jueceshu2.txt";

JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

JavaRDD<LabeledPoint> ceshi = MLUtils.loadLibSVMFile(jsc.sc(), ceshipath).toJavaRDD();

List<LabeledPoint> take = data.take(3);

for (LabeledPoint labeledPoint : take) {

System.out.println("----->"+labeledPoint.features());

System.out.println("----->"+labeledPoint.label());

}

JavaRDD<LabeledPoint> trainingData = data;

JavaRDD<LabeledPoint> testData = ceshi;

 

// Set parameters.

// Empty categoricalFeaturesInfo indicates all features are continuous.

Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();

String impurity = "variance";

Integer maxDepth = 5;

Integer maxBins = 32;

// Train a DecisionTree model.

final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,categoricalFeaturesInfo, impurity, maxDepth, maxBins);

// Evaluate model on test instances and compute test error

JavaPairRDD<Double, Double> predictionAndLabel =

testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {

@Override

public Tuple2<Double, Double> call(LabeledPoint p) {

return new Tuple2<Double, Double>(model.predict(p.features()), p.label());

}

});

System.out.println(predictionAndLabel.take(10));

Double testMSE =predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {

@Override

public Double call(Tuple2<Double, Double> pl) {

Double diff = pl._1() - pl._2();

return diff * diff;

}

}).reduce(new Function2<Double, Double, Double>() {

@Override

public Double call(Double a, Double b) {

return a + b;

}

}) / data.count();

System.out.println("Test Mean Squared Error: " + testMSE);

System.out.println("Learned regression tree model:\n" + model.toDebugString());

// Save and load model

//model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");

// DecisionTreeModel sameModel = DecisionTreeModel

// .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");

// $example off$

}

(4)提交运行

 

(5)结果查看

----->(4,[0,1,2,3],[32.0,1.0,1.0,0.0])

----->0.0

----->(4,[0,1,2,3],[25.0,1.0,2.0,0.0])

----->0.0

----->(4,[0,1,2,3],[29.0,1.0,2.0,1.0])

----->1.0

[(0.0,0.0), (1.0,1.0), (1.0,1.0), (1.0,1.0), (0.0,0.0)]

Test Mean Squared Error: 0.0

Learned regression tree model:

DecisionTreeModel regressor of depth 4 with 11 nodes

If (feature 1 <= 0.0)

Predict: 0.0

Else (feature 1 > 0.0)

If (feature 3 <= 0.0)

If (feature 0 <= 30.0)

If (feature 2 <= 1.0)

Predict: 1.0

Else (feature 2 > 1.0)

Predict: 0.0

Else (feature 0 > 30.0)

Predict: 0.0

Else (feature 3 > 0.0)

If (feature 2 <= 0.0)

Predict: 0.0

Else (feature 2 > 0.0)

Predict: 1.0

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值