Spark MLlib Java: Decision Tree Classification (DecisionTree)

Decision trees in brief:

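A decision tree classifies a sample using a tree structure. It starts at the root, and each internal node tests one feature and sends the sample down one branch. This repeats level by level until the sample reaches a leaf, and the leaf gives the result (the predicted class). I found a very good diagram illustrating this on another blog.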
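To make the "level by level" idea concrete, here is a minimal hand-built sketch in plain Java, with no Spark. The tree, its thresholds, and the names TreeWalkSketch, Node, Leaf and Split are hypothetical and only for illustration. MLlib learns the splits from the training data. The If/Else structure is roughly what `toDebugString()` prints for a trained model.

```java
// Sketch of how a decision tree classifies one sample: start at the root,
// test one feature per internal node, follow a branch, stop at a leaf.
// The tree below is hand-built and purely illustrative (not learned from data).
public class TreeWalkSketch {

    interface Node {
        double predict(double[] features);
    }

    // Leaf node: holds the predicted class label.
    static final class Leaf implements Node {
        private final double label;

        Leaf(double label) {
            this.label = label;
        }

        public double predict(double[] features) {
            return label;
        }
    }

    // Internal node: go left if features[featureIndex] <= threshold, else right.
    static final class Split implements Node {
        private final int featureIndex;
        private final double threshold;
        private final Node left;
        private final Node right;

        Split(int featureIndex, double threshold, Node left, Node right) {
            this.featureIndex = featureIndex;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        public double predict(double[] features) {
            return features[featureIndex] <= threshold
                    ? left.predict(features)
                    : right.predict(features);
        }
    }

    // Hypothetical depth-2 tree over 3 features and 3 classes (0, 1, 2).
    static Node buildExampleTree() {
        return new Split(0, 2.5,
                new Leaf(0.0),
                new Split(2, 1.0, new Leaf(1.0), new Leaf(2.0)));
    }

    public static void main(String[] args) {
        Node root = buildExampleTree();
        System.out.println(root.predict(new double[] {1.0, 0.0, 0.0})); // 0.0
        System.out.println(root.predict(new double[] {3.0, 0.0, 0.5})); // 1.0
        System.out.println(root.predict(new double[] {3.0, 0.0, 2.0})); // 2.0
    }
}
```

Each prediction costs at most one comparison per level. That is why the `maxDepth` parameter used later also bounds prediction cost.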

Data used by the Spark MLlib Java program:

   Training data: C:\hello\trainData.txt


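In this file, the number before the comma is the label (the target class), and the numbers after the comma form the feature vector, separated by spaces. Because the program below sets numClasses = 3, the labels must be 0, 1 or 2.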
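The data file itself is not reproduced here, so the line used below is a made-up example in the same format. This minimal plain-Java sketch, with no Spark, shows how one training line splits into a label and a feature array. It also includes the label check implied by numClasses = 3. TrainLineParser, Example and isValidLabel are hypothetical names used only for illustration.

```java
import java.util.Arrays;

// Sketch: parse one training line "label,f1 f2 f3 ..." without Spark.
// TrainLineParser, Example and isValidLabel are illustrative names only.
public class TrainLineParser {

    // Holder for one parsed training example.
    static final class Example {
        final double label;
        final double[] features;

        Example(double label, double[] features) {
            this.label = label;
            this.features = features;
        }
    }

    // Split on the first comma: label before it, space-separated features after it.
    static Example parse(String line) {
        String[] parts = line.split(",", 2);
        if (parts.length != 2) {
            throw new IllegalArgumentException("expected \"label,features\": " + line);
        }
        double label = Double.parseDouble(parts[0].trim());
        String[] tokens = parts[1].trim().split("\\s+");
        double[] features = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            features[i] = Double.parseDouble(tokens[i]);
        }
        return new Example(label, features);
    }

    // MLlib classification expects integer labels 0, 1, ..., numClasses - 1.
    static boolean isValidLabel(double label, int numClasses) {
        return label >= 0 && label < numClasses && label == Math.floor(label);
    }

    public static void main(String[] args) {
        Example e = parse("2,0.5 1.0 3.25"); // hypothetical line
        System.out.println(e.label + " -> " + Arrays.toString(e.features)); // 2.0 -> [0.5, 1.0, 3.25]
        System.out.println(isValidLabel(e.label, 3)); // true
    }
}
```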
   Test data: C:\hello\testData.txt

This file contains only feature vectors (no label), with values separated by spaces.
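Test lines carry no label, so each one should have exactly as many features as the training lines. A mismatch almost always points to a data error. Below is a small plain-Java sketch that parses a test line with that check. TestLineParser and parseWithDimension are hypothetical names. The dimension 3 matches the three features the original program assumed.

```java
// Sketch: parse one test line "f1 f2 f3 ..." (features only, no label)
// and check that it has the expected number of features.
// TestLineParser and parseWithDimension are illustrative names only.
public class TestLineParser {

    // Split on runs of whitespace and parse each token as a double.
    static double[] parse(String line) {
        String trimmed = line.trim();
        if (trimmed.isEmpty()) {
            throw new IllegalArgumentException("empty line");
        }
        String[] tokens = trimmed.split("\\s+");
        double[] features = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            features[i] = Double.parseDouble(tokens[i]);
        }
        return features;
    }

    // Reject lines whose feature count differs from the training data.
    static double[] parseWithDimension(String line, int expectedDim) {
        double[] features = parse(line);
        if (features.length != expectedDim) {
            throw new IllegalArgumentException(
                    "expected " + expectedDim + " features but got " + features.length + ": " + line);
        }
        return features;
    }

    public static void main(String[] args) {
        double[] f = parseWithDimension("0.5 1.0 3.25", 3); // hypothetical line
        System.out.println(f.length); // 3
    }
}
```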

Spark MLlib DecisionTree Java program:

package MLlibTest;

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;

import scala.Tuple2;

public class DecisionTreeTest {

    // Parses space-separated numbers into a dense vector (any number of features).
    private static Vector parseFeatures(String text) {
        String[] tokens = text.trim().split("\\s+");
        double[] values = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            values[i] = Double.parseDouble(tokens[i]);
        }
        return Vectors.dense(values);
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("DecisionTreeTest").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        // Training data: each line is "label,f1 f2 f3 ..."
        JavaRDD<String> lines = jsc.textFile("C:/hello/trainData.txt");
        JavaRDD<LabeledPoint> transdata = lines.map(new Function<String, LabeledPoint>() {
            private static final long serialVersionUID = 1L;

            @Override
            public LabeledPoint call(String str) throws Exception {
                String[] t1 = str.split(",");
                double label = Double.parseDouble(t1[0].trim());
                return new LabeledPoint(label, parseFeatures(t1[1]));
            }
        });

        // Decision tree parameters, then train the model
        int numClasses = 3;       // labels must be 0.0, 1.0 or 2.0
        // Empty map: every feature is treated as continuous
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        String impurity = "gini"; // split criterion ("gini" or "entropy")
        int maxDepth = 5;         // maximum depth of the tree
        int maxBins = 32;         // bins used to discretize continuous features
        final DecisionTreeModel tree_model = DecisionTree.trainClassifier(
                transdata, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
        System.out.println("Decision tree model:");
        System.out.println(tree_model.toDebugString());

        // Save the model (fails if the directory already exists)
        tree_model.save(jsc.sc(), "C:/hello/DecisionTreeModel");

        // Load the unlabeled test data and predict a class for each line
        JavaRDD<String> testLines = jsc.textFile("C:/hello/testData.txt");
        JavaPairRDD<String, String> res = testLines.mapToPair(new PairFunction<String, String, String>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Tuple2<String, String> call(String line) throws Exception {
                double prediction = tree_model.predict(parseFeatures(line));
                return new Tuple2<String, String>(line, Double.toString(prediction));
            }
        }).cache(); // cached because res is used twice below (print and save)

        // Print the results (in local mode, foreach output appears in the console)
        res.foreach(new VoidFunction<Tuple2<String, String>>() {
            private static final long serialVersionUID = 1L;

            @Override
            public void call(Tuple2<String, String> a) throws Exception {
                System.out.println(a._1() + " : " + a._2());
            }
        });

        // Save the results locally (fails if the output directory already exists)
        res.saveAsTextFile("C:/hello/res");

        jsc.stop();
    }
}
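The program sets impurity = "gini". Here is a rough plain-Java sketch of what that criterion measures. It is not MLlib's internal code. The Gini impurity of a node with class proportions p_k is 1 - sum of p_k^2. A candidate split is scored by its information gain: the parent's impurity minus the size-weighted impurity of the two children. MLlib picks the split with the largest gain. The class counts below are made up for illustration, and GiniSketch, gini and infoGain are hypothetical names.

```java
// Sketch of Gini impurity and the information gain of a binary split.
// GiniSketch, gini and infoGain are illustrative names, not MLlib API.
public class GiniSketch {

    static long sum(long[] counts) {
        long total = 0;
        for (long c : counts) {
            total += c;
        }
        return total;
    }

    // Gini impurity from per-class counts: 1 - sum over k of p_k^2.
    static double gini(long[] counts) {
        long total = sum(counts);
        if (total == 0) {
            return 0.0;
        }
        double sumSq = 0.0;
        for (long c : counts) {
            double p = (double) c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    // Gain = gini(parent) - (nL/n) * gini(left) - (nR/n) * gini(right),
    // where the parent counts are the element-wise sum of left and right.
    static double infoGain(long[] left, long[] right) {
        long[] parent = new long[left.length];
        for (int i = 0; i < left.length; i++) {
            parent[i] = left[i] + right[i];
        }
        long nL = sum(left);
        long nR = sum(right);
        long n = nL + nR;
        double weighted = (double) nL / n * gini(left) + (double) nR / n * gini(right);
        return gini(parent) - weighted;
    }

    public static void main(String[] args) {
        // Hypothetical node with 4 samples of each of 3 classes
        System.out.printf("%.4f%n", gini(new long[] {4, 4, 4})); // 0.6667
        // Split that isolates class 0 on the left
        System.out.printf("%.4f%n", infoGain(new long[] {4, 0, 0}, new long[] {0, 4, 4})); // 0.3333
    }
}
```

A pure node (all samples in one class) has impurity 0. This is why splits that separate the classes cleanly score the highest gain.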

Conclusion:

     This was put together in a hurry, so please point out any mistakes; let's learn from each other.


