Intuitively, the idea behind decision tree classification is a lot like matchmaking. Imagine a girl's mother wants to introduce her daughter to a potential boyfriend, and the following conversation takes place:
Daughter: How old is he? Mother: 26. Daughter: Is he handsome? Mother: Quite handsome. Daughter: Is his income high? Mother: Not very high, about average. Daughter: Is he a civil servant? Mother: Yes, he works at the tax bureau. Daughter: All right, I'll go meet him.
(1) Training data
0 1:32 2:1 3:1 4:0
0 1:25 2:1 3:2 4:0
1 1:29 2:1 3:2 4:1
1 1:24 2:1 3:1 4:0
0 1:31 2:1 3:1 4:0
1 1:35 2:1 3:2 4:1
0 1:30 2:0 3:1 4:0
0 1:31 2:1 3:1 4:0
1 1:30 2:1 3:2 4:1
1 1:21 2:1 3:1 4:0
0 1:21 2:1 3:2 4:0
1 1:21 2:1 3:2 4:1
0 1:29 2:0 3:2 4:1
0 1:29 2:1 3:0 4:1
0 1:29 2:0 3:2 4:1
1 1:30 2:1 3:1 4:0
Explanation: each line is in LibSVM format. The label is the decision (1 = meet, 0 = don't meet); the features are `index:value` pairs, e.g. `1:32` (feature 1, age: 32), `2:1` (feature 2, looks: handsome or average), `3:1` (feature 3, income: high), `4:0` (feature 4, civil servant: no).
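To make the LibSVM layout above concrete, here is a minimal plain-Java sketch of how one such line decomposes into a label plus a dense feature vector. The class and method names (`LibSVMLineDemo`, `parseLibSVMLine`) are chosen here for illustration; in the actual program this parsing is done by Spark's `MLUtils.loadLibSVMFile`.

```java
import java.util.Arrays;

public class LibSVMLineDemo {
    /**
     * Parses one LibSVM-format line ("label idx:val idx:val ...") into a
     * dense array where slot 0 holds the label and slots 1..numFeatures
     * hold the feature values (LibSVM indices are 1-based; missing
     * indices default to 0.0).
     */
    static double[] parseLibSVMLine(String line, int numFeatures) {
        String[] tokens = line.trim().split("\\s+");
        double[] row = new double[numFeatures + 1]; // row[0] = label
        row[0] = Double.parseDouble(tokens[0]);
        for (int i = 1; i < tokens.length; i++) {
            String[] kv = tokens[i].split(":");
            row[Integer.parseInt(kv[0])] = Double.parseDouble(kv[1]);
        }
        return row;
    }

    public static void main(String[] args) {
        // First training row: don't meet; age 32, handsome, income 1, not a civil servant.
        double[] row = parseLibSVMLine("0 1:32 2:1 3:1 4:0", 4);
        System.out.println(Arrays.toString(row));
    }
}
```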
(2) Test data
0 1:32 2:1 3:2 4:0
1 1:27 2:1 3:1 4:1
1 1:29 2:1 3:1 4:0
1 1:25 2:1 3:2 4:1
0 1:23 2:0 3:2 4:1
(3) Sample code
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

public static void main(String[] args) {
    SparkConf sparkConf = new SparkConf()
            .setAppName("JavaDecisionTreeRegressionExample")
            .setMaster("local");
    JavaSparkContext jsc = new JavaSparkContext(sparkConf);
    // Load the training and test data in LibSVM format.
    String datapath = "jueceshu.txt";
    String ceshipath = "jueceshu2.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
    JavaRDD<LabeledPoint> ceshi = MLUtils.loadLibSVMFile(jsc.sc(), ceshipath).toJavaRDD();
    // Print the first three training records to verify the data was parsed correctly.
    List<LabeledPoint> take = data.take(3);
    for (LabeledPoint labeledPoint : take) {
        System.out.println("----->" + labeledPoint.features());
        System.out.println("----->" + labeledPoint.label());
    }
    JavaRDD<LabeledPoint> trainingData = data;
    JavaRDD<LabeledPoint> testData = ceshi;
    // Set parameters.
    // An empty categoricalFeaturesInfo indicates all features are continuous.
    Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
    String impurity = "variance";
    Integer maxDepth = 5;
    Integer maxBins = 32;
    // Train a DecisionTree regression model.
    final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
            categoricalFeaturesInfo, impurity, maxDepth, maxBins);
    // Evaluate the model on the test instances and compute the test error.
    JavaPairRDD<Double, Double> predictionAndLabel =
            testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                @Override
                public Tuple2<Double, Double> call(LabeledPoint p) {
                    return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
                }
            });
    System.out.println(predictionAndLabel.take(10));
    // Mean squared error: sum the squared (prediction - label) differences,
    // then divide by the number of TEST instances.
    Double testMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
        @Override
        public Double call(Tuple2<Double, Double> pl) {
            Double diff = pl._1() - pl._2();
            return diff * diff;
        }
    }).reduce(new Function2<Double, Double, Double>() {
        @Override
        public Double call(Double a, Double b) {
            return a + b;
        }
    }) / testData.count();
    System.out.println("Test Mean Squared Error: " + testMSE);
    System.out.println("Learned regression tree model:\n" + model.toDebugString());
    // Save and load the model.
    // model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
    // DecisionTreeModel sameModel = DecisionTreeModel
    //         .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
    jsc.stop();
}
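Note that the test error is just the mean of squared (prediction − label) differences. As a quick sanity check outside Spark, the same arithmetic can be reproduced in plain Java over the (prediction, label) pairs that the job prints; `MseDemo` is a name chosen here for illustration.

```java
public class MseDemo {
    /** Mean squared error over paired predictions and labels. */
    static double mse(double[] predictions, double[] labels) {
        double sum = 0.0;
        for (int i = 0; i < predictions.length; i++) {
            double diff = predictions[i] - labels[i];
            sum += diff * diff;
        }
        return sum / predictions.length;
    }

    public static void main(String[] args) {
        // Pairs in the form printed by the job: [(prediction, label), ...]
        double[] predictions = {0.0, 1.0, 1.0, 1.0, 0.0};
        double[] labels      = {0.0, 1.0, 1.0, 1.0, 0.0};
        System.out.println("Test Mean Squared Error: " + mse(predictions, labels));
    }
}
```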
(4) Submit and run
(5) Viewing the results
----->(4,[0,1,2,3],[32.0,1.0,1.0,0.0])
----->0.0
----->(4,[0,1,2,3],[25.0,1.0,2.0,0.0])
----->0.0
----->(4,[0,1,2,3],[29.0,1.0,2.0,1.0])
----->1.0
[(0.0,0.0), (1.0,1.0), (1.0,1.0), (1.0,1.0), (0.0,0.0)]
Test Mean Squared Error: 0.0
Learned regression tree model:
DecisionTreeModel regressor of depth 4 with 11 nodes
  If (feature 1 <= 0.0)
   Predict: 0.0
  Else (feature 1 > 0.0)
   If (feature 3 <= 0.0)
    If (feature 0 <= 30.0)
     If (feature 2 <= 1.0)
      Predict: 1.0
     Else (feature 2 > 1.0)
      Predict: 0.0
    Else (feature 0 > 30.0)
     Predict: 0.0
   Else (feature 3 > 0.0)
    If (feature 2 <= 0.0)
     Predict: 0.0
    Else (feature 2 > 0.0)
     Predict: 1.0
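The printed model is readable as nested if/else branches over the four features (feature 0 = age, feature 1 = looks, feature 2 = income, feature 3 = civil servant). As an illustration, the `toDebugString` output above can be hand-translated into plain Java; the class and method names (`TreeDemo`, `predict`) are chosen here and are not part of Spark.

```java
public class TreeDemo {
    /**
     * Hand-translation of the learned tree.
     * f[0] = age, f[1] = looks, f[2] = income, f[3] = civil servant.
     */
    static double predict(double[] f) {
        if (f[1] <= 0.0) return 0.0;           // feature 1 <= 0: don't meet
        if (f[3] <= 0.0) {                     // not a civil servant
            if (f[0] <= 30.0) {
                return f[2] <= 1.0 ? 1.0 : 0.0;
            }
            return 0.0;                        // age > 30: don't meet
        }
        return f[2] <= 0.0 ? 0.0 : 1.0;        // civil servant branch
    }

    public static void main(String[] args) {
        // The five test rows (age, looks, income, civil servant) and their labels.
        double[][] rows = {
            {32, 1, 2, 0}, {27, 1, 1, 1}, {29, 1, 1, 0}, {25, 1, 2, 1}, {23, 0, 2, 1}
        };
        double[] labels = {0, 1, 1, 1, 0};
        for (int i = 0; i < rows.length; i++) {
            System.out.println("predicted " + predict(rows[i]) + ", label " + labels[i]);
        }
    }
}
```

Running the test rows through this translation reproduces the `(prediction, label)` pairs printed above, which is why the test MSE comes out to 0.0.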