Starting with this chapter, we analyze the source code of Spark MLlib's decision tree implementation.
First, let's look at the official Java example of using a decision tree. Its path is /home/yangqiao/codes/spark/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
For convenience, part of the analysis is given directly as comments in the code:
public final class JavaDecisionTree {
  public static void main(String[] args) {
    String datapath = "data/mllib/sample_libsvm_data.txt"; // path of the sample data file shipped with Spark
    if (args.length == 1) {
      datapath = args[0]; // a custom data file can be passed as the first argument instead
    } else if (args.length > 1) { // too many arguments: print usage and exit
      System.err.println("Usage: JavaDecisionTree <libsvm format data file>");
      System.exit(1);
    }
    SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
    // Spark configuration; here it only sets the application name
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    // JavaSparkContext is the main entry point of a Spark program: it connects to the
    // Spark cluster and is used to create RDDs and share variables on the cluster.
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
    /* cache() marks the RDD to be persisted in memory after it is first computed,
       so later actions reuse it instead of re-reading the file.
       loadLibSVMFile parses LIBSVM-format labeled data into an RDD[LabeledPoint]. */
    // Count how many classes the data contains
    Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
      @Override public Double call(LabeledPoint p) {
        return p.label();
      }
    }).countByValue().size();
    /* map applies the function to every element of the RDD and returns a new RDD;
       here the function extracts each point's label. countByValue then returns a
       (value, count) map over the distinct values, so its size is the number of classes. */
    // Set the parameters
    // An empty categoricalFeaturesInfo indicates all features are continuous.
    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
    String impurity = "gini";
    Integer maxDepth = 5;
    Integer maxBins = 32;
    // Gini impurity is used as the split criterion, the maximum tree depth is 5,
    // and continuous features are discretized into at most 32 bins.
    // Train a decision tree model for classification
    final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
      categoricalFeaturesInfo, impurity, maxDepth, maxBins);
    // The rest of the code uses the model for prediction and evaluation. Since the
    // focus here is the model-building process, evaluation is only touched on briefly.
    JavaPairRDD<Double, Double> predictionAndLabel =
      data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
        @Override public Tuple2<Double, Double> call(LabeledPoint p) {
          return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
        }
      });
    Double trainErr =
      1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
        @Override public Boolean call(Tuple2<Double, Double> pl) {
          return !pl._1().equals(pl._2());
        }
      }).count() / data.count();
    System.out.println("Training error: " + trainErr);
    System.out.println("Learned classification tree model:\n" + model);
    // Train a DecisionTree model for regression.
    impurity = "variance";
    final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data,
      categoricalFeaturesInfo, impurity, maxDepth, maxBins);
    // Evaluate model on training instances and compute training error
    JavaPairRDD<Double, Double> regressorPredictionAndLabel =
      data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
        @Override public Tuple2<Double, Double> call(LabeledPoint p) {
          return new Tuple2<Double, Double>(regressionModel.predict(p.features()), p.label());
        }
      });
    Double trainMSE =
      regressorPredictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
        @Override public Double call(Tuple2<Double, Double> pl) {
          Double diff = pl._1() - pl._2();
          return diff * diff;
        }
      }).reduce(new Function2<Double, Double, Double>() {
        @Override public Double call(Double a, Double b) {
          return a + b;
        }
      }) / data.count();
    System.out.println("Training Mean Squared Error: " + trainMSE);
    System.out.println("Learned regression tree model:\n" + regressionModel);
    sc.stop();
  }
}
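The numClasses computation above (map each point to its label, then countByValue) can be mimicked without Spark. The following is an illustrative plain-Java sketch, not Spark's implementation; the label values are made up:

```java
import java.util.HashMap;
import java.util.Map;

public class CountByValueSketch {
    // Mimics RDD.map(p -> p.label()).countByValue():
    // counts the occurrences of each distinct label.
    static Map<Double, Long> countByValue(double[] labels) {
        Map<Double, Long> counts = new HashMap<>();
        for (double label : labels) {
            counts.merge(label, 1L, Long::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        // Made-up labels standing in for p.label() of each LabeledPoint
        double[] labels = {0.0, 1.0, 1.0, 0.0, 1.0};
        Map<Double, Long> counts = countByValue(labels);
        int numClasses = counts.size(); // analogous to countByValue().size()
        System.out.println(numClasses);      // 2
        System.out.println(counts.get(1.0)); // 3
    }
}
```

The key point is that only the number of distinct label values matters for numClasses; the per-label counts are computed but then discarded by size().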
The code above shows that the core of model construction is:
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins);
In other words, building the model comes down to the trainClassifier method. To see what trainClassifier actually does, we need to dig into the source code.
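Before diving into the source, it helps to recall what the impurity = "gini" parameter means: for class proportions p_i at a node, Gini impurity is 1 - Σ p_i². Here is a minimal plain-Java sketch of that formula (MLlib's actual implementation lives in a separate Gini object inside the tree package):

```java
public class GiniSketch {
    // Gini impurity: 1 - sum over classes of (classCount / total)^2
    static double gini(long[] classCounts) {
        long total = 0;
        for (long c : classCounts) total += c;
        if (total == 0) return 0.0;
        double sumSquares = 0.0;
        for (long c : classCounts) {
            double p = (double) c / total;
            sumSquares += p * p;
        }
        return 1.0 - sumSquares;
    }

    public static void main(String[] args) {
        System.out.println(gini(new long[]{5, 5}));  // evenly split node: 0.5
        System.out.println(gini(new long[]{10, 0})); // pure node: 0.0
    }
}
```

A candidate split is good when it moves child nodes toward impurity 0 (pure); the training algorithm picks the split that most reduces this quantity.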
Open the source file at the following path:
/home/yangqiao/codes/spark/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
We will focus on the DecisionTree.scala file; the detailed analysis continues in the next post of this series.
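As a closing aside, the two metrics the example prints (classification training error and regression MSE) reduce to simple arithmetic over (prediction, label) pairs. A plain-Java sketch with made-up values in place of model.predict(...) and p.label():

```java
public class EvalSketch {
    // Fraction of pairs where the prediction differs from the label,
    // as in the filter(...).count() / data.count() step above
    static double trainingError(double[] predictions, double[] labels) {
        int wrong = 0;
        for (int i = 0; i < labels.length; i++) {
            if (predictions[i] != labels[i]) wrong++;
        }
        return (double) wrong / labels.length;
    }

    // Mean of squared differences, as in the map-then-reduce MSE step above
    static double mse(double[] predictions, double[] labels) {
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = predictions[i] - labels[i];
            sum += diff * diff;
        }
        return sum / labels.length;
    }

    public static void main(String[] args) {
        // Made-up prediction/label pairs; one of four is wrong
        double[] predictions = {1.0, 0.0, 1.0, 1.0};
        double[] labels      = {1.0, 0.0, 0.0, 1.0};
        System.out.println(trainingError(predictions, labels)); // 0.25
        System.out.println(mse(predictions, labels));           // 0.25
    }
}
```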