决策树分类通常被视为一种非概率模型（不过 Spark 的 DecisionTreeClassifier 也会输出各类别的预测概率）。本例的测试数据集采用网上公开的泰坦尼克号乘客数据，使用 Spark ML 的决策树分类算法 DecisionTreeClassifier，以 Pclass（船舱等级）、Sex（性别）、Age（年龄）三个特征来预测乘客是否获救。
pom.xml
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.penngo.spark.ml</groupId>
    <artifactId>sparkml</artifactId>
    <packaging>jar</packaging>
    <version>1.0-SNAPSHOT</version>
    <name>sparkml</name>
    <url>http://maven.apache.org</url>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
        <!-- Single source of truth for the Java language level used by the compiler plugin below. -->
        <java.version>1.8</java.version>
        <!-- All Spark modules must share the same Spark version and Scala binary version. -->
        <spark.version>2.2.3</spark.version>
        <scala.binary.version>2.11</scala.binary.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_${scala.binary.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.7.0</version>
                <configuration>
                    <!-- Use the java.version property instead of repeating the literal level. -->
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                    <encoding>UTF-8</encoding>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
DecisionTreeClassification.java
package com.penngo.spark.ml.main;
import org.apache.log4j.Logger;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.File;
import org.apache.spark.sql.functions;
import static org.apache.spark.sql.types.DataTypes.DoubleType;
/**
 * Spark decision-tree classification with {@link DecisionTreeClassifier}.
 *
 * <p>Reads {@code data/titanic.txt} (CSV with a header row, resolved against the current working
 * directory). Converts the {@code pclass}, {@code age} and {@code sex} columns into a numeric
 * feature vector. Trains a decision tree on a random ~70% of the rows, then prints the error on
 * the remaining ~30% together with the learned tree.
 */
public class DecisionTreeClassification {
    private static final Logger log = Logger.getLogger(DecisionTreeClassification.class);

    /** Shared session; created by {@link #initSpark()} and cleared again when {@link #run()} stops it. */
    private static SparkSession spark = null;

    /**
     * Creates the shared {@link SparkSession} if none exists yet and logs its configuration.
     *
     * <p>On non-Windows hosts no master is set here, so it has to be supplied externally
     * (presumably by spark-submit). On Windows a {@code local[3]} master is used for local
     * debugging. The Hadoop-related system properties get default values unless they were already
     * set, e.g. with {@code -Dhadoop.home.dir=...}.
     */
    public static void initSpark() {
        if (spark == null) {
            String os = System.getProperty("os.name").toLowerCase();
            if (!os.contains("windows")) {
                // Running on Linux / a cluster.
                spark = SparkSession
                        .builder()
                        .appName("DecisionTreeClassification")
                        .getOrCreate();
            } else {
                // Running on Windows for local debugging.
                setPropertyIfAbsent("hadoop.home.dir", "D:/hadoop/hadoop-2.7.6");
                setPropertyIfAbsent("HADOOP_USER_NAME", "hadoop");
                spark = SparkSession
                        .builder()
                        .appName("DecisionTreeClassification").master("local[3]")
                        .getOrCreate();
            }
        }
        log.warn("spark.conf().getAll()=============" + spark.conf().getAll());
    }

    /** Sets a JVM system property only if it has not been set already (e.g. via {@code -D}). */
    private static void setPropertyIfAbsent(String key, String value) {
        if (System.getProperty(key) == null) {
            System.setProperty(key, value);
        }
    }

    /**
     * Loads the Titanic data, trains a decision-tree pipeline and prints the test error and the
     * learned tree.
     *
     * <p>Calls {@link #initSpark()} first if no session exists. The session is always stopped when
     * this method finishes, whether normally or with an exception. The shared reference is then
     * cleared, so a later {@link #initSpark()} creates a fresh session instead of reusing a
     * stopped one.
     */
    public static void run() {
        if (spark == null) {
            initSpark();
        }
        try {
            String dataPath = new File("").getAbsolutePath() + "/data/titanic.txt";
            Dataset<Row> data = spark.read().option("header", "true").csv(dataPath);
            data.show();

            Dataset<Row> data2 = buildFeatures(data);
            data2.show();

            // Index the label column, adding label metadata to it.
            StringIndexerModel labelIndexer = new StringIndexer()
                    .setInputCol("label")
                    .setOutputCol("indexedLabel")
                    .fit(data2);
            // Automatically identify categorical features and index them. VectorIndexer's default
            // maxCategories is 20, so a feature with more than 20 distinct values is treated as
            // continuous. pclass1 (3 values) and sex (2 values) are therefore categorical.
            VectorIndexerModel featureIndexer = new VectorIndexer()
                    .setInputCol("features")
                    .setOutputCol("indexedFeatures")
                    .fit(data2);

            // Split the data into training and test sets (30% held out for testing).
            Dataset<Row>[] splits = data2.randomSplit(new double[]{0.7, 0.3});
            Dataset<Row> trainingData = splits[0];
            Dataset<Row> testData = splits[1];

            // Decision tree with default hyper-parameters. Tunables include:
            //   setImpurity("entropy")     - "gini" impurity (default) or "entropy"
            //   setMaxBins(100)            - max number of bins for discretizing continuous features
            //   setMaxDepth(5)             - maximum depth of the tree
            //   setMinInfoGain(0.01)       - minimum information gain required to split a node
            //   setMinInstancesPerNode(10) - minimum number of instances per child node
            //   setSeed(123456)
            DecisionTreeClassifier dt = new DecisionTreeClassifier()
                    .setLabelCol("indexedLabel")
                    .setFeaturesCol("indexedFeatures");

            // Map indexed predictions back to the original label values.
            IndexToString labelConverter = new IndexToString()
                    .setInputCol("prediction")
                    .setOutputCol("predictedLabel")
                    .setLabels(labelIndexer.labels());

            // Chain indexers and tree in a Pipeline; the tree is stage index 2 (read back below).
            Pipeline pipeline = new Pipeline()
                    .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});

            // Train on the training split, then predict the held-out rows.
            PipelineModel model = pipeline.fit(trainingData);
            Dataset<Row> predictions = model.transform(testData);
            predictions.select("user_id", "features", "label", "prediction").show();

            // Test error = 1 - accuracy on the held-out rows.
            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                    .setLabelCol("indexedLabel")
                    .setPredictionCol("prediction")
                    .setMetricName("accuracy");
            double accuracy = evaluator.evaluate(predictions);
            System.out.println("Test Error = " + (1.0 - accuracy));

            // Print the learned tree.
            DecisionTreeClassificationModel treeModel =
                    (DecisionTreeClassificationModel) (model.stages()[2]);
            System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
        } finally {
            spark.stop();
            spark = null;
        }
    }

    /**
     * Converts the raw string columns into numeric columns and assembles them into a
     * {@code features} vector.
     *
     * <p>Output columns: {@code user_id}, {@code label}, {@code pclass1}, {@code age1},
     * {@code sex}, {@code features}.
     *
     * @param data raw rows as read from the CSV file (all columns are strings)
     * @return the numeric data with the assembled feature vector
     * @throws IllegalStateException if no {@code age} value is numeric, so no mean can be imputed
     */
    private static Dataset<Row> buildFeatures(Dataset<Row> data) {
        // mean() casts the string "age" column to double; non-numeric entries such as "NA" become
        // null and are ignored by the aggregate.
        Double meanAge = data.select(functions.mean("age").as("mage")).first().getAs("mage");
        if (meanAge == null) {
            throw new IllegalStateException(
                    "Column 'age' contains no numeric values; cannot impute missing ages");
        }
        // Missing ages are replaced by the mean age truncated to a whole number.
        double fillAge = meanAge.intValue();

        Dataset<Row> numeric = data.select(
                functions.col("user_id"),
                functions.col("survived").cast(DoubleType).as("label"),
                functions.when(functions.col("pclass").equalTo("1st"), 1)
                        .when(functions.col("pclass").equalTo("2nd"), 2)
                        .when(functions.col("pclass").equalTo("3rd"), 3)
                        .cast(DoubleType).as("pclass1"),
                // Any age that does not parse as a number ("NA", empty, null) gets the mean age.
                functions.coalesce(functions.col("age").cast(DoubleType), functions.lit(fillAge))
                        .as("age1"),
                functions.when(functions.col("sex").equalTo("female"), 0).otherwise(1).as("sex"))
                // VectorAssembler (Spark 2.2) and the label indexing cannot handle nulls, so drop
                // rows whose label or passenger class could not be determined.
                .na().drop(new String[]{"label", "pclass1"});

        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[]{"pclass1", "age1", "sex"})
                .setOutputCol("features");
        return assembler.transform(numeric);
    }

    /** Entry point: initializes Spark, then trains and evaluates the model. */
    public static void main(String[] args) {
        initSpark();
        run();
    }
}
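基础数据（data.show() 的输出）
过滤、特征化后的数据（data2.show() 的输出）
预测结果（predictions.select(...).show() 的输出）
预测错误率（Test Error）和学习到的决策树模型（toDebugString() 的输出）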