Spark 官方示例并没有说明如何修改决策树训练时使用的内存（即 Strategy 的 maxMemoryInMB 参数），所以我自己琢磨了一下，用 Java 写出了通过 Strategy 配置训练参数的示例，代码如下：
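导入的包（其中 LogisticRegression、LogisticRegressionModel、DataFrame、SQLContext、DataTypes、StructField、StructType 在本例中并未用到，可以删掉）：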
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.impurity.Gini$;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
整体的代码:
// Local mode for a quick test. To run on a cluster, use e.g.
// setMaster("spark://spark-master:7077") and settings such as "spark.cores.max".
// Note: spark.executor.memory only sizes executor JVMs; in local mode everything
// runs inside the driver JVM, so it has no effect there.
SparkConf conf = new SparkConf()
        .setAppName("test_spark")
        .set("spark.executor.memory", "30m")
        .setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf);

// Toy training set. For classification, labels must lie in [0, numClasses);
// here they are 1 and 2, so numClasses below must be at least 3.
List<LabeledPoint> labeledPoints = Arrays.asList(
        new LabeledPoint(1, Vectors.dense(2.0, 2.0, 2.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(1, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)),
        new LabeledPoint(2, Vectors.dense(1.0, 1.0, 1.0)));
JavaRDD<LabeledPoint> b = sc.parallelize(labeledPoints);

// Strategy's constructor wants a Scala immutable Map (feature index -> arity);
// an empty map treats every feature as continuous.
scala.collection.immutable.Map<Object, Object> cate = new scala.collection.immutable.HashMap<Object, Object>();

// Named arguments for the 13-argument Strategy constructor, in constructor order.
int maxDepth = 5;
int numClasses = 3;
int maxBins = 32;
int minInstancesPerNode = 1;
double minInfoGain = 0.0;
// Memory budget (MB) for aggregating split statistics -- the knob this demo is about.
int maxMemoryInMB = 512;
// Fraction of the training data used to grow the tree. Values below 1.0 train on a
// random sample; with only 8 points, 0.1 keeps about one point and collapses the tree
// to a single leaf. 1.0 is the library default.
double subsamplingRate = 1.0;
boolean useNodeIdCache = false;
// Library default; only consulted when useNodeIdCache is true.
int checkpointInterval = 10;

Gini$ gini = Gini.instance();
Strategy strategy = new Strategy(
        org.apache.spark.mllib.tree.configuration.Algo.Classification(),
        gini,
        maxDepth,
        numClasses,
        maxBins,
        org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort(),
        cate,
        minInstancesPerNode,
        minInfoGain,
        maxMemoryInMB,
        subsamplingRate,
        useNodeIdCache,
        checkpointInterval);
DecisionTreeModel model = DecisionTree.train(b.rdd(), strategy);
System.out.println("model decision:" + model.toDebugString());

// Release the SparkContext (and its local executor threads) when done.
sc.stop();
当 subsamplingRate 取 0.1 时（每次只随机抽取约 10% 的训练数据），输出如下：树退化为只有 1 个节点，直接预测 2.0。
model decision:DecisionTreeModel classifier of depth 0 with 1 nodes
Predict: 2.0
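方法二：也可以用下面这种方式构造 Strategy 并修改内存。上面的例子结果有点怪异，原因是构造函数的 subsamplingRate 参数传了 0.1，只用了约 10% 的数据来训练；改成 1.0（默认值）即可。另外我又研究了其他实现方法，下面这种从默认配置出发的初始化方式也能得到正确结果。注意它用来替换上面构造 strategy 的那部分代码，不要和上面的代码放在同一个作用域里，否则 strategy、model 会重复声明。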
// Alternative: start from the library defaults and override only what is needed.
// Drop-in replacement for the explicit Strategy construction (it redeclares
// `strategy` and `model`); reuses the training RDD `b`.
// defaultStrategy supersedes the misspelled defaultStategy, deprecated since Spark 1.5.0.
Strategy strategy = Strategy.defaultStrategy(org.apache.spark.mllib.tree.configuration.Algo.Classification());
// Labels must lie in [0, numClasses); the default of 2 is too small for label 2.0.
strategy.setNumClasses(3);
// Memory budget (MB) for aggregating split statistics.
strategy.setMaxMemoryInMB(512);
DecisionTreeModel model = DecisionTree.train(b.rdd(), strategy);
这里的示例仅仅是为了验证它能否正常运行。需要根据自己情况调整的参数，请参考官方文档。
下面是 Spark 1.5.2 的 API 文档，请对照自己的 Spark 版本查阅。注意：从 1.5.0 起 Strategy.defaultStategy（拼写有误）已被弃用，应改用 Strategy.defaultStrategy。
https://spark.apache.org/docs/1.5.2/api/java/org/apache/spark/mllib/tree/configuration/Strategy.html