Apache Spark: MLlib Decision Tree Operations (Java)

Version used: Spark 2.4.6

1. Preface

This post is part of my study of Spark MLlib, covering decision trees and random forests. It mainly follows Chapter 4 of Advanced Analytics with Spark, "Predicting Forest Cover with Decision Trees". The original code is written in Scala; here it is reimplemented in Java.

Data preparation: dataset download link

Pull the last two lines out of the dataset to use as prediction data. Each line consists of 54 feature columns (10 numeric features, a 4-column one-hot wilderness-area block, and a 40-column one-hot soil-type block) followed by a class label from 1 to 7:

2384,170,15,60,5,90,230,245,143,864,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3
2383,165,13,60,4,67,231,244,141,875,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3
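As an optional sanity check (a minimal sketch; the file name and location of covtype.data are assumptions), you can verify that every line has the expected 55 columns:

import java.nio.file.Files;
import java.nio.file.Paths;

public class CovtypeCheck {
	public static void main(String[] args) throws Exception {
		// each line should contain 54 feature columns + 1 label = 55 columns
		long bad = Files.lines(Paths.get("covtype.data"))
				.filter(l -> l.split(",").length != 55)
				.count();
		System.out.println("Lines with unexpected column count: " + bad);
	}
}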

2. Main Code

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.rdd.RDD;

import scala.Tuple2;

/**
 * 
 * @author hy
 * @createTime 2021-09-11 08:26:26
 * @description Decision-tree examples adapted from "Advanced Analytics with Spark":
 * 1. decision tree
 * 2. random forest
 *
 */
public class DecisionTreeTest {
	public static void main(String[] args) {
		SparkConf conf = new SparkConf().setMaster("local").setAppName("test");
		JavaSparkContext jsc = new JavaSparkContext(conf);
		jsc.setLogLevel("WARN");
		// operation
		decisionTreeTest(jsc);
		jsc.close();
	}
	
	/**
	 * 
	 * @author hy
	 * @createTime 2021-09-11 12:53:01
	 * @description Treat all columns as numeric features and build LabeledPoints
	 * @param rawData
	 * @return
	 *
	 */
	private static JavaRDD<LabeledPoint> createLabeledPointDataUsingBefore(JavaRDD<String> rawData) {
		JavaRDD<LabeledPoint> data = rawData.map(new Function<String, LabeledPoint>() {

			@Override
			public LabeledPoint call(String v1) throws Exception {
				String[] strings = v1.split(",");
				double[] values = new double[strings.length];
				for (int i = 0; i < strings.length; i++) {
					values[i] = Double.valueOf(strings[i]);
				}
				double[] noLastNumValues = Arrays.copyOf(values, values.length - 1);
				Vector featureVector = Vectors.dense(noLastNumValues);
				// the last value is the label: the class id, an integer from 1 to 7
				double label = values[values.length - 1] - 1;
				// DecisionTree requires labels to start at 0, hence the -1
				LabeledPoint labeledPoint = new LabeledPoint(label,featureVector);
				return labeledPoint;
			}
		});
		return data;
	}
	
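	// Returns the index of the first occurrence of findValue in array, or -1 if
	// not found; used to decode a one-hot block into a single categorical index.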
	private static int indexOfArray(double[] array,double findValue) {
		int index = -1;
		for (int i = 0; i < array.length; i++) {
			if(findValue==array[i]) {
				index=i;
				break;
			}
		}
		return index;
	}
	
	/**
	 * 
	 * @author hy
	 * @createTime 2021-09-11 13:45:03
	 * @description Use the categorical (typed) encoding for the one-hot blocks
	 * @param rawData
	 * @return
	 *
	 */
	private static JavaRDD<LabeledPoint> createLabeledPointDataUsingType(JavaRDD<String> rawData) {
		JavaRDD<LabeledPoint> data = rawData.map(new Function<String, LabeledPoint>() {

			@Override
			public LabeledPoint call(String v1) throws Exception {
				String[] strings = v1.split(",");
				double[] values = new double[strings.length];
				for (int i = 0; i < strings.length; i++) {
					values[i] = Double.valueOf(strings[i]);
				}
				
				// the one-hot decoding is factored out into createVectorByLine(...)
				Vector featureVector = createVectorByLine(v1);
				// the last value is the label: the class id, an integer from 1 to 7
				double label = values[values.length - 1] - 1;
				// DecisionTree requires labels to start at 0, hence the -1
				LabeledPoint labeledPoint = new LabeledPoint(label,featureVector);
				return labeledPoint;
			}
		});
		return data;
	}
	
	
	/**
	 * 
	 * @author hy
	 * @createTime 2021-09-12 08:12:23
	 * @description Convert one CSV line into a feature vector, collapsing the
	 * two one-hot blocks into two categorical values
	 * @param line
	 * @return
	 *
	 */
	private static Vector createVectorByLine(String line) {
		String[] strings = line.split(",");
		double[] values = new double[strings.length];
		for (int i = 0; i < strings.length; i++) {
			values[i] = Double.valueOf(strings[i]);
		}
		
		// categorical feature: wilderness area (columns 10-13, one-hot)
		double[] wildernessValues = Arrays.copyOfRange(values, 10, 14);
		double wilderness = indexOfArray(wildernessValues, 1.0);
		
		// categorical feature: soil type (columns 14-53, one-hot)
		double[] soilValues = Arrays.copyOfRange(values, 14, 54);
		double soil = indexOfArray(soilValues, 1.0);
		
		// keep the 10 numeric columns and append the two categorical values
		double[] copyOfRange = Arrays.copyOfRange(values, 0, 10);
		double[] copyOf = Arrays.copyOf(copyOfRange, copyOfRange.length + 2);
		copyOf[copyOf.length - 2] = wilderness;
		copyOf[copyOf.length - 1] = soil;
		
		Vector featureVector = Vectors.dense(copyOf);
		return featureVector;
	}

	private static void decisionTreeTest(JavaSparkContext jsc) {
		JavaRDD<String> rawData = jsc.textFile("C:\\Users\\admin\\Desktop\\mldata\\covtype.data");
		JavaRDD<LabeledPoint> data = createLabeledPointDataUsingType(rawData);

		// split the data: 80% training, 10% cross-validation, 10% test
		JavaRDD<LabeledPoint>[] randomSplit = data.randomSplit(new double[] { 0.8, 0.1, 0.1 });

		JavaRDD<LabeledPoint> trainData = randomSplit[0];
		JavaRDD<LabeledPoint> cvData = randomSplit[1];
		JavaRDD<LabeledPoint> testData = randomSplit[2];
		trainData.cache();
		cvData.cache();
		testData.cache();

		// build the decision-tree model: use trainClassifier for a categorical
		// target, trainRegressor for a numeric target
		HashMap<Integer, Integer> hashMap = new HashMap<Integer, Integer>();
		// arguments: 7 = number of target classes; the map declares categorical
		// features (left empty here, so features 10 and 11 are treated as numeric
		// even though createLabeledPointDataUsingType encodes them as categories;
		// createDecisionRandomForest below shows the categorical declaration);
		// "gini" = impurity (the alternative is "entropy"); 4 = max depth;
		// 100 = max bins
		DecisionTreeModel model = DecisionTree.trainClassifier(trainData, 7, hashMap, "gini", 4, 100);
		// random-forest variant (too slow to run on a single machine)
		//RandomForestModel model = createDecisionRandomForest(trainData);
		
		MulticlassMetrics metrics = getMetrics(model, cvData);
		Matrix confusionMatrix = metrics.confusionMatrix();
		System.out.println(confusionMatrix);
		
		// the next two lines both report overall accuracy; precision() is the
		// deprecated alias of accuracy() and returns the same value
		System.out.println("accuracy: " + metrics.accuracy());
		System.out.println("precision: " + metrics.precision());
		
		// per-class precision and recall (one tuple per class)
		List<Tuple2<Double,Double>> list=new ArrayList<>();
		for (int i = 0; i < 7; i++) {
			Tuple2<Double, Double> tuple2 = new Tuple2<Double,Double>(metrics.precision(i),metrics.recall(i));
			list.add(tuple2);
		}
		
		System.out.println("输出与其他对比精度:");
		list.forEach(x->{System.out.println(x);});
		
		// baseline: a model that randomly guesses a class according to the class
		// priors; its expected accuracy is the sum over classes of the product of
		// the priors in the two datasets
		Double[] testPriorProbabilities = classProbabilities(testData);
		Double[] cvPriorProbabilities = classProbabilities(cvData);
		
		double sum = 0.0;
		for (int i = 0; i < cvPriorProbabilities.length; i++) {
			sum += cvPriorProbabilities[i] * testPriorProbabilities[i];
		}
		System.out.println("random-guess baseline accuracy: " + sum);
		
		
		// Hyperparameter tuning: in practice, loop over impurity, depth and bin
		// settings, keep the model with the best cvData accuracy, and only then
		// evaluate it once on testData; here a single stronger configuration is
		// trained on the union of trainData and cvData.
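		// A minimal tuning-loop sketch (commented out so the example below runs
		// unchanged; the hyperparameter values are illustrative assumptions):
		/*
		for (String impurity : new String[] { "gini", "entropy" }) {
			for (int depth : new int[] { 1, 20 }) {
				for (int bins : new int[] { 10, 300 }) {
					DecisionTreeModel candidate =
							DecisionTree.trainClassifier(trainData, 7, hashMap, impurity, depth, bins);
					double acc = getMetrics(candidate, cvData).accuracy();
					System.out.println(impurity + ", " + depth + ", " + bins + " => " + acc);
				}
			}
		}
		*/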
		DecisionTreeModel newModel = DecisionTree.trainClassifier(trainData.union(cvData), 7, hashMap, "entropy", 20, 300);
		MulticlassMetrics newMetrics = getMetrics(newModel, testData);
		double accuracy = newMetrics.accuracy();
		System.out.println("优化后的决策树对cvData的准确度:"+accuracy);
		
		
		// sample input for the random forest (12 features: 10 numeric + 2 categorical)
		/* String input="2709,125,28,67,23,3224,253,207,61,4094,0,29"; */
		// decision-tree prediction on the two held-out lines
		String[] lines= {"2384,170,15,60,5,90,230,245,143,864,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3" ,
				"2383,165,13,60,4,67,231,244,141,875,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3"};
		for (int i = 0; i < lines.length; i++) {
			String line = lines[i];
			Vector vector = createVectorByLine(line);
			double predict = newModel.predict(vector);
			System.out.println("预测结果为:"+predict);//预测后的标签为2.0,实际需要+1所以预测结果是正确的
		}
		
	}
	
	
	/**
	 * 
	 * @author hy
	 * @createTime 2021-09-11 13:50:08
	 * @description Build a random forest (features 10 and 11 declared categorical)
	 *
	 */
	private static RandomForestModel createDecisionRandomForest(JavaRDD<LabeledPoint> trainData) {
		Map<Integer, Integer> hashMap=new HashMap<>();
		hashMap.put(10, 4);
		hashMap.put(11, 40);
		// parameters: 20 trees, "auto" feature-subset strategy, "entropy" impurity,
		// max depth 30, 300 bins, seed 10; this did not finish within 20 minutes
		// on a single machine
		RandomForestModel randomForestModel = RandomForest.trainClassifier(trainData, 7, hashMap, 20, "auto", "entropy", 30, 300, 10);
		return randomForestModel;
	}

	// per-class prior probabilities: the fraction of the data carrying each label
	private static Double[] classProbabilities(JavaRDD<LabeledPoint> data) {
		Map<Double, Long> countByValue = data.map(x->x.label()).countByValue();
		List<Tuple2<Double,Long>> counts=new ArrayList<>();
		Long sum = 0L;
		for (Entry<Double, Long> entry : countByValue.entrySet()) {
			counts.add(new Tuple2<Double, Long>(entry.getKey(), entry.getValue()));
			sum+=entry.getValue();
		}
		//System.out.println("sum==>"+sum);
		//System.out.println(counts);
		counts.sort(new Comparator<Tuple2<Double, Long>>() {

			@Override
			public int compare(Tuple2<Double, Long> o1, Tuple2<Double, Long> o2) {
				// sort by label ascending so that index i corresponds to class i
				return Double.compare(o1._1, o2._1);
			}
		});
		Double[] returnValues=new Double[counts.size()];
		for (int i = 0; i < returnValues.length; i++) {
			returnValues[i]=counts.get(i)._2/(sum*1.0);
		}
		//System.out.println(Arrays.toString(returnValues));
		return returnValues;
		
	}
	
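	// Pairs each example's prediction with its true label and wraps the pairs in
	// MulticlassMetrics, which expects an RDD of (prediction, label) tuples.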
	private static MulticlassMetrics getMetrics(DecisionTreeModel model, JavaRDD<LabeledPoint> data) {
		RDD<Tuple2<Object, Object>> rdd = data.map(new Function<LabeledPoint, Tuple2<Object, Object>>() {

			@Override
			public Tuple2<Object, Object> call(LabeledPoint example) throws Exception {
				double predict = model.predict(example.features());
				return new Tuple2<Object, Object>(predict, example.label());
			}
		}).rdd();

		return new MulticlassMetrics(rdd);

	}
	
	private static MulticlassMetrics getMetrics(RandomForestModel model, JavaRDD<LabeledPoint> data) {
		RDD<Tuple2<Object, Object>> rdd = data.map(new Function<LabeledPoint, Tuple2<Object, Object>>() {
			@Override
			public Tuple2<Object, Object> call(LabeledPoint example) throws Exception {
				double predict = model.predict( example.features());
				return new Tuple2<Object, Object>(predict, example.label());
			}
		}).rdd();

		return new MulticlassMetrics(rdd);

	}
}

3. Test Results

Confusion matrix (rows = actual class, columns = predicted class, class labels ascending):

13611.0  7091.0   14.0    0.0  0.0   0.0  326.0   
5215.0   22528.0  532.0   0.0  8.0   0.0  39.0    
0.0      372.0    3255.0  0.0  0.0   0.0  0.0     
0.0      0.0      293.0   0.0  0.0   0.0  0.0     
0.0      913.0    33.0    0.0  12.0  0.0  0.0     
0.0      464.0    1265.0  0.0  0.0   0.0  0.0     
1064.0   25.0     0.0     0.0  0.0   0.0  1051.0  
accuracy: 0.6962020959887113
precision: 0.6962020959887113
per-class (precision, recall):
(0.6843137254901961,0.6468491588252068)
(0.7176122065428598,0.7954240519737307)
(0.6036721068249258,0.8974358974358975)
(0.0,0.0)
(0.6,0.012526096033402923)
(0.0,0.0)
(0.742231638418079,0.4911214953271028)
random-guess baseline accuracy: 0.37602803637168397
tuned decision tree accuracy on testData: 0.9208749848932166
prediction: 2.0
prediction: 2.0

A random-forest run belongs here as well, but it is too time-consuming on a single machine, so it was skipped.

The predicted label is 2.0; adding 1 recovers the original class id (3), which matches the label on the two held-out lines above.

4. Summary

1. Spark MLlib models operate on feature Vectors plus a label, wrapped together in a LabeledPoint (see the sketch after this list)

2. Train models with different hyperparameter settings, evaluate each on held-out data, and use the resulting accuracy to judge model quality

3. Model accuracy depends strongly on the chosen hyperparameters

4. I am still unclear on how the feature vectors for this dataset were originally derived and defined
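
As a minimal illustration of point 1 (a standalone sketch, not part of the pipeline above; the feature values are taken from the first held-out line using the 12-feature categorical encoding):

import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

public class LabeledPointExample {
	public static void main(String[] args) {
		// 10 numeric features plus wilderness index 2 and soil index 1;
		// label 2.0 corresponds to class 3 in the original 1-based encoding
		LabeledPoint p = new LabeledPoint(2.0,
				Vectors.dense(2384, 170, 15, 60, 5, 90, 230, 245, 143, 864, 2, 1));
		System.out.println(p);
	}
}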
