Spark decision tree classification: DecisionTreeClassifier

A decision tree classifier is a non-probabilistic model. This example uses the publicly available Titanic passenger dataset and trains Spark's DecisionTreeClassifier on three features (Pclass, Sex, and Age) to predict whether a passenger survived.
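For background, Spark's decision tree picks each split by maximizing the reduction in node impurity. The default criterion is Gini impurity, and entropy can be selected through setImpurity (see the commented-out options in the code below):

\mathrm{Gini}(D) = 1 - \sum_{k} p_k^2, \qquad \mathrm{Entropy}(D) = -\sum_{k} p_k \log_2 p_k

where p_k is the fraction of samples in node D that belong to class k.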
pom.xml

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>com.penngo.spark.ml</groupId>
  <artifactId>sparkml</artifactId>
  <packaging>jar</packaging>
  <version>1.0-SNAPSHOT</version>
  <name>sparkml</name>
  <url>http://maven.apache.org</url>
  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    <java.version>1.8</java.version>
  </properties>
  <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>3.8.1</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.11</artifactId>
      <version>2.2.3</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.11</artifactId>
      <version>2.2.3</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_2.11</artifactId>
      <version>2.2.3</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_2.11</artifactId>
      <version>2.2.3</version>
    </dependency>
  </dependencies>
  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>3.7.0</version>
        <configuration>
          <source>1.8</source>
          <target>1.8</target>
          <encoding>UTF-8</encoding>
        </configuration>
      </plugin>
    </plugins>
  </build>
</project>
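To run this on a cluster, package the module with Maven and submit the jar via spark-submit. A minimal sketch, assuming a YARN cluster and that data/titanic.txt sits in the driver's working directory (both assumptions, since the original post does not show the submit command):

mvn package
spark-submit --class com.penngo.spark.ml.main.DecisionTreeClassification \
  --master yarn \
  target/sparkml-1.0-SNAPSHOT.jar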

DecisionTreeClassification.java

package com.penngo.spark.ml.main;

import org.apache.log4j.Logger;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.File;
import org.apache.spark.sql.functions;
import static org.apache.spark.sql.types.DataTypes.DoubleType;

/**
 * Spark decision tree classification with DecisionTreeClassifier.
 */
public class DecisionTreeClassification {
    private static Logger log = Logger.getLogger(DecisionTreeClassification.class);
    private static SparkSession spark = null;



    public static void initSpark(){
        if (spark == null) {
            String os = System.getProperty("os.name").toLowerCase();
            // Running on Linux (e.g. submitted to a cluster)
            if(os.indexOf("windows") == -1){
                spark = SparkSession
                        .builder()
                        .appName("DecisionTreeClassification")
                        .getOrCreate();
            }
            // Running on Windows, for local debugging
            else{
                System.setProperty("hadoop.home.dir", "D:/hadoop/hadoop-2.7.6");
                System.setProperty("HADOOP_USER_NAME", "hadoop");
                spark = SparkSession
                        .builder()
                        .appName("DecisionTreeClassification" ).master("local[3]")
                        .getOrCreate();
            }
        }
        log.warn("spark.conf().getAll()=============" + spark.conf().getAll());
    }

    public static void run(){
        String dataPath = new File("").getAbsolutePath() + "/data/titanic.txt";
        Dataset<Row> data = spark.read().option("header", "true").csv(dataPath);
        data.show();
        //data.describe()
        //Dataset<Row> datana2 = data.na().fill(ImmutableMap.of("age", "30", "ticket", "1111"));

        Dataset<Row> meanDataset = data.select(functions.mean("age").as("mage"));
        Double mage = meanDataset.first().getAs("mage");
        // Convert string columns to numeric values and handle missing values
        Dataset<Row> data1 = data.select(
                functions.col("user_id"),
                functions.col("survived").cast(DoubleType).as("label"),
                functions.when(functions.col("pclass").equalTo("1st"), 1)
                        .when(functions.col("pclass").equalTo("2nd"), 2)
                        .when(functions.col("pclass").equalTo("3rd"), 3)
                        .cast(DoubleType).as("pclass1"),
                functions.when(functions.col("age").equalTo("NA"), mage.intValue()).otherwise(functions.col("age")).cast(DoubleType).as("age1"),
                functions.when(functions.col("sex").equalTo("female"), 0).otherwise(1).as("sex")
        );

        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[]{"pclass1", "age1", "sex"})
                .setOutputCol("features");
        Dataset<Row> data2 = assembler.transform(data1);
        data2.show();
        // Index labels, adding metadata to the label column
        StringIndexerModel labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel")
                .fit(data2);
        // Automatically identify categorical features and index them.
        // Features with more than maxCategories distinct values (default 20) are treated as continuous.
        VectorIndexerModel featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                //.setMaxCategories(3)
                .fit(data2);
        // Split the data into training and test sets (30% held out for testing)
        Dataset<Row>[] splits = data2.randomSplit(new double[]{0.7, 0.3});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testData = splits[1];

        // Set up the decision tree classifier
        DecisionTreeClassifier dt = new DecisionTreeClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("indexedFeatures");
        //.setImpurity("entropy") // impurity measure: "gini" (default) or "entropy"
        //.setMaxBins(100) // max number of bins for discretizing continuous features
        //.setMaxDepth(5) // maximum depth of the tree
        //.setMinInfoGain(0.01) // minimum information gain required for a split
        //.setMinInstancesPerNode(10) // minimum number of instances each child node must have
        //.setSeed(123456)

        // Convert indexed predictions back to the original label strings
        IndexToString labelConverter = new IndexToString()
                .setInputCol("prediction")
                .setOutputCol("predictedLabel")
                .setLabels(labelIndexer.labels());


        // Chain indexers and tree in a Pipeline.
        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});

        // Train the model; this also runs the indexer stages
        PipelineModel model = pipeline.fit(trainingData);

        // Make predictions on the held-out test set
        Dataset<Row> predictions = model.transform(testData);

        predictions.select("user_id", "features", "label", "prediction").show();
        //predictions.select("predictedLabel", "label", "features").show(5);

        // Compute the accuracy on the test data and report the error rate
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("accuracy");
        double accuracy = evaluator.evaluate(predictions);
        System.out.println("Test Error = " + (1.0 - accuracy));

        // Inspect the learned decision tree
        DecisionTreeClassificationModel treeModel =
                (DecisionTreeClassificationModel) (model.stages()[2]);
        System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());

        spark.stop();
    }
    public static void main(String[] args){
        initSpark();
        run();
    }
}
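The fitted PipelineModel can also be persisted and reloaded, so the model does not have to be retrained before every scoring run. A minimal sketch to append at the end of run(); the path model/dtc is a hypothetical example, not from the original post:

        // Save the fitted pipeline to disk (the path is a hypothetical example)
        model.write().overwrite().save("model/dtc");
        // Load it back later and reuse it for scoring
        PipelineModel sameModel = PipelineModel.load("model/dtc");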

Raw input data

Filtered and featurized data

Prediction results

Test error and the learned tree model
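As an extension beyond the original post, the fitted pipeline can also score a new, unlabeled passenger. StringIndexerModel skips itself when its input column is absent, so no dummy label column is required. This sketch assumes spark, assembler, and model from run() are still in scope, and the passenger values are invented for illustration:

        // Extra imports needed: java.util.Collections, org.apache.spark.sql.RowFactory,
        // org.apache.spark.sql.types.DataTypes, org.apache.spark.sql.types.StructType
        // A hypothetical 25-year-old male travelling in 3rd class
        StructType schema = new StructType()
                .add("pclass1", DataTypes.DoubleType)
                .add("age1", DataTypes.DoubleType)
                .add("sex", DataTypes.DoubleType);
        Dataset<Row> passenger = spark.createDataFrame(
                Collections.singletonList(RowFactory.create(3.0, 25.0, 1.0)), schema);
        // Reuse the same VectorAssembler, then run the fitted pipeline
        Dataset<Row> scored = model.transform(assembler.transform(passenger));
        scored.select("features", "predictedLabel").show();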

Reposted from: https://my.oschina.net/penngo/blog/3018547
