决策树分类(Decision tree classifier)

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{ VectorAssembler, IndexToString, StringIndexer, VectorIndexer }


// Entry point to Spark SQL. getOrCreate() hands back the session produced by this builder.
val spark = {
  val sessionBuilder = SparkSession
    .builder()
    .appName("Spark decision tree classifier")
    .config("spark.some.config.option", "some-value") // NOTE(review): looks like a placeholder option
  sessionBuilder.getOrCreate()
}

// Implicit conversions, e.g. for turning RDDs and local collections into DataFrames.
import spark.implicits._

// Sample rows only. For the complete data source see the blog post at
// http://blog.csdn.net/hadoop_spark_storm/article/details/53412598
// Tuple layout: (affairs, gender, age, yearsmarried, children,
//                religiousness, education, occupation, rating)
val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
  (0.0, "male",   37.0, 10.0, "no",  3.0, 18.0, 7.0, 4.0),
  (0.0, "female", 27.0, 4.0,  "no",  4.0, 14.0, 6.0, 4.0),
  (0.0, "female", 32.0, 15.0, "yes", 1.0, 12.0, 1.0, 4.0),
  (0.0, "male",   57.0, 15.0, "yes", 5.0, 18.0, 6.0, 5.0),
  (0.0, "male",   22.0, 0.75, "no",  2.0, 17.0, 6.0, 3.0),
  (0.0, "female", 32.0, 1.5,  "no",  2.0, 17.0, 5.0, 5.0)
)


// Turn the local tuples into a DataFrame, naming the nine tuple positions in order.
val data = {
  val columnNames = Seq(
    "affairs", "gender", "age", "yearsmarried", "children",
    "religiousness", "education", "occupation", "rating")
  dataList.toDF(columnNames: _*)
}
data: org.apache.spark.sql.DataFrame = [affairs: double, gender: string ... 7 more fields]

data.printSchema() 
root 
 |-- affairs: double (nullable = false) 
 |-- gender: string (nullable = true) 
 |-- age: double (nullable = false) 
 |-- yearsmarried: double (nullable = false) 
 |-- children: string (nullable = true) 
 |-- religiousness: double (nullable = false) 
 |-- education: double (nullable = false) 
 |-- occupation: double (nullable = false) 
 |-- rating: double (nullable = false) 


data.show(10,truncate=false)
+-------+------+----+------------+--------+-------------+---------+----------+------+
|affairs|gender|age |yearsmarried|children|religiousness|education|occupation|rating|
+-------+------+----+------------+--------+-------------+---------+----------+------+
|0.0    |male  |37.0|10.0        |no      |3.0          |18.0     |7.0       |4.0   |
|0.0    |female|27.0|4.0         |no      |4.0          |14.0     |6.0       |4.0   |
|0.0    |female|32.0|15.0        |yes     |1.0          |12.0     |1.0       |4.0   |
|0.0    |male  |57.0|15.0        |yes     |5.0          |18.0     |6.0       |5.0   |
|0.0    |male  |22.0|0.75        |no      |2.0          |17.0     |6.0       |3.0   |
|0.0    |female|32.0|1.5         |no      |2.0          |17.0     |5.0       |5.0   |
|0.0    |female|22.0|0.75        |no      |2.0          |12.0     |1.0       |3.0   |
|0.0    |male  |57.0|15.0        |yes     |2.0          |14.0     |4.0       |4.0   |
|0.0    |female|32.0|15.0        |yes     |4.0          |16.0     |1.0       |2.0   |
|0.0    |male  |22.0|1.5         |no      |4.0          |14.0     |4.0       |5.0   |
+-------+------+----+------------+--------+-------------+---------+----------+------+
only showing top 10 rows

// Inspect how the data is distributed: summary statistics for every column
data.describe("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating").show(10,truncate=false)
+-------+------------------+------+-----------------+-----------------+--------+------------------+-----------------+-----------------+------------------+
|summary|affairs           |gender|age              |yearsmarried     |children|religiousness     |education        |occupation       |rating            |
+-------+------------------+------+-----------------+-----------------+--------+------------------+-----------------+-----------------+------------------+
|count  |601               |601   |601              |601              |601     |601               |601              |601              |601               |
|mean   |1.4559068219633944|null  |32.48752079866888|8.17769550748752 |null    |3.1164725457570714|16.16638935108153|4.194675540765391|3.9317803660565724|
|stddev |3.298757728494681 |null  |9.28876170487667 |5.571303149963791|null    |1.1675094016730692|2.402554565766698|1.819442662708579|1.1031794920503795|
|min    |0.0               |female|17.5             |0.125            |no      |1.0               |9.0              |1.0              |1.0               |
|max    |12.0              |male  |57.0             |15.0             |yes     |5.0               |20.0             |7.0              |5.0               |
+-------+------------------+------+-----------------+-----------------+--------+------------------+-----------------+-----------------+------------------+

// Register the DataFrame as the temporary view "data" so the SQL text below can select from it.
data.createOrReplaceTempView("data")

// Convert string / count columns into numeric values via SQL select-expressions.
// label: 0 when affairs = 0, otherwise 1.
val labelWhere: String =
  "case when affairs=0 then 0 else cast(1 as double) end as label"
labelWhere: String = case when affairs=0 then 0 else cast(1 as double) end as label

// gender: 0 for 'female', otherwise 1.
val genderWhere: String =
  "case when gender='female' then 0 else cast(1 as double) end as gender"
genderWhere: String = case when gender='female' then 0 else cast(1 as double) end as gender

// children: 0 for 'no', otherwise 1.
val childrenWhere: String =
  "case when children='no' then 0 else cast(1 as double) end as children"
childrenWhere: String = case when children='no' then 0 else cast(1 as double) end as children

// Select from the "data" view: the three recoded expressions (label, gender, children)
// alongside the remaining columns, which are passed through unchanged.
val dataLabelDF = {
  val query =
    s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data"
  spark.sql(query)
}
dataLabelDF: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 7 more fields]

// Names of the columns fed to the VectorAssembler, in the order they appear in the vector.
val featuresArray: Array[String] = Array(
  "gender",
  "age",
  "yearsmarried",
  "children",
  "religiousness",
  "education",
  "occupation",
  "rating"
)
featuresArray: Array[String] = Array(gender, age, yearsmarried, children, religiousness, education, occupation, rating)

// Combine the individual feature columns into a single vector column named "features".
val assembler = new VectorAssembler()
  .setInputCols(featuresArray)
  .setOutputCol("features")
assembler: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_6e2c6bdd631e

// Apply the assembler: the result keeps the input columns and adds the "features" vector column.
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 8 more fields]

vecDF.show(10,truncate=false)
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+
|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features                            |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+
|0.0  |1.0   |37.0|10.0        |0.0     |3.0          |18.0     |7.0       |4.0   |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|
|0.0  |0.0   |27.0|4.0         |0.0     |4.0          |14.0     |6.0       |4.0   |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |
|0.0  |0.0   |32.0|15.0        |1.0     |1.0          |12.0     |1.0       |4.0   |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|
|0.0  |1.0   |57.0|15.0        |1.0     |5.0          |18.0     |6.0       |5.0   |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|
|0.0  |1.0   |22.0|0.75        |0.0     |2.0          |17.0     |6.0       |3.0   |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|
|0.0  |0.0   |32.0|1.5         |0.0     |2.0          |17.0     |5.0       |5.0   |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |
|0.0  |0.0   |22.0|0.75        |0.0     |2.0          |12.0     |1.0       |3.0   |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|
|0.0  |1.0   |57.0|15.0        |1.0     |2.0          |14.0     |4.0       |4.0   |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|
|0.0  |0.0   |32.0|15.0        |1.0     |4.0          |16.0     |1.0       |2.0   |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|
|0.0  |1.0   |22.0|1.5         |0.0     |4.0          |14.0     |4.0       |5.0   |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+
only showing top 10 rows


// Index the label column, adding metadata to the resulting "indexedLabel" column.
// The indexer is fitted on vecDF as a whole.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(vecDF)
labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_d00cad619cd5

labelIndexer.transform(vecDF).show(10,truncate=false)
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+
|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features                            |indexedLabel|
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+
|0.0  |1.0   |37.0|10.0        |0.0     |3.0          |18.0     |7.0       |4.0   |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|0.0         |
|0.0  |0.0   |27.0|4.0         |0.0     |4.0          |14.0     |6.0       |4.0   |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |0.0         |
|0.0  |0.0   |32.0|15.0        |1.0     |1.0          |12.0     |1.0       |4.0   |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|0.0         |
|0.0  |1.0   |57.0|15.0        |1.0     |5.0          |18.0     |6.0       |5.0   |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|0.0         |
|0.0  |1.0   |22.0|0.75        |0.0     |2.0          |17.0     |6.0       |3.0   |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|0.0         |
|0.0  |0.0   |32.0|1.5         |0.0     |2.0          |17.0     |5.0       |5.0   |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |0.0         |
|0.0  |0.0   |22.0|0.75        |0.0     |2.0          |12.0     |1.0       |3.0   |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|0.0         |
|0.0  |1.0   |57.0|15.0        |1.0     |2.0          |14.0     |4.0       |4.0   |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|0.0         |
|0.0  |0.0   |32.0|15.0        |1.0     |4.0          |16.0     |1.0       |2.0   |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|0.0         |
|0.0  |1.0   |22.0|1.5         |0.0     |4.0          |14.0     |4.0       |5.0   |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |0.0         |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+
only showing top 10 rows

// Automatically identify categorical features and index them.
// Features with more than 8 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(8)
  .fit(vecDF)
featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_8fbcad97fb60

featureIndexer.transform(vecDF).show(10,truncate=false)
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+----------------------------------+
|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features                            |indexedFeatures                   |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+----------------------------------+
|0.0  |1.0   |37.0|10.0        |0.0     |3.0          |18.0     |7.0       |4.0   |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|[1.0,37.0,6.0,0.0,2.0,5.0,6.0,3.0]|
|0.0  |0.0   |27.0|4.0         |0.0     |4.0          |14.0     |6.0       |4.0   |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |[0.0,27.0,4.0,0.0,3.0,2.0,5.0,3.0]|
|0.0  |0.0   |32.0|15.0        |1.0     |1.0          |12.0     |1.0       |4.0   |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|[0.0,32.0,7.0,1.0,0.0,1.0,0.0,3.0]|
|0.0  |1.0   |57.0|15.0        |1.0     |5.0          |18.0     |6.0       |5.0   |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|[1.0,57.0,7.0,1.0,4.0,5.0,5.0,4.0]|
|0.0  |1.0   |22.0|0.75        |0.0     |2.0          |17.0     |6.0       |3.0   |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|[1.0,22.0,2.0,0.0,1.0,4.0,5.0,2.0]|
|0.0  |0.0   |32.0|1.5         |0.0     |2.0          |17.0     |5.0       |5.0   |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |[0.0,32.0,3.0,0.0,1.0,4.0,4.0,4.0]|
|0.0  |0.0   |22.0|0.75        |0.0     |2.0          |12.0     |1.0       |3.0   |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|[0.0,22.0,2.0,0.0,1.0,1.0,0.0,2.0]|
|0.0  |1.0   |57.0|15.0        |1.0     |2.0          |14.0     |4.0       |4.0   |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|[1.0,57.0,7.0,1.0,1.0,2.0,3.0,3.0]|
|0.0  |0.0   |32.0|15.0        |1.0     |4.0          |16.0     |1.0       |2.0   |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|[0.0,32.0,7.0,1.0,3.0,3.0,0.0,1.0]|
|0.0  |1.0   |22.0|1.5         |0.0     |4.0          |14.0     |4.0       |5.0   |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |[1.0,22.0,3.0,0.0,3.0,2.0,3.0,4.0]|
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+----------------------------------+
only showing top 10 rows

// Split the data into training and test sets (30% held out for testing).
// A fixed seed is supplied so the split, and with it the reported metrics, is repeatable
// for the same input. It is the same value handed to the classifier through setSeed.
val Array(trainingData, testData) = vecDF.randomSplit(Array(0.7, 0.3), seed = 123456L)
trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 8 more fields]
testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 8 more fields]

// Configure the decision-tree learner; it is trained when the pipeline is fitted.
val dt = new DecisionTreeClassifier()
  .setSeed(123456)
  .setFeaturesCol("indexedFeatures")
  .setLabelCol("indexedLabel")
  .setImpurity("entropy")     // impurity measure
  .setMaxDepth(5)             // maximum depth of the tree
  .setMaxBins(100)            // maximum number of bins used to discretize "continuous features"
  .setMinInstancesPerNode(10) // minimum number of samples each node must contain
  .setMinInfoGain(0.01)       // minimum information gain for a node split, a value in [0,1]

// Convert predicted indices back to the original labels, using the labels of the fitted indexer.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)
labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_2598e79a1d08

// Pipeline stages, in order: label indexer, feature indexer, decision tree, label converter.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Fit the whole pipeline on the training split; this also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions on the test split.
val predictions = model.transform(testData)
predictions: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 14 more fields]


// Display a few example rows: predicted label, actual label and the feature vector
predictions.select("predictedLabel", "label", "features").show(10,truncate=false)
+--------------+-----+-------------------------------------+
|predictedLabel|label|features                             |
+--------------+-----+-------------------------------------+
|0.0           |0.0  |[0.0,22.0,0.125,0.0,2.0,14.0,4.0,5.0]|
|0.0           |0.0  |[0.0,22.0,0.125,0.0,2.0,16.0,6.0,3.0]|
|0.0           |0.0  |[0.0,22.0,0.125,0.0,4.0,12.0,4.0,5.0]|
|0.0           |0.0  |[0.0,22.0,0.417,0.0,1.0,17.0,6.0,4.0]|
|0.0           |0.0  |[0.0,22.0,0.75,0.0,2.0,16.0,5.0,5.0] |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,1.0,14.0,1.0,5.0]  |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,2.0,14.0,5.0,4.0]  |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,2.0,16.0,5.0,5.0]  |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,3.0,16.0,6.0,5.0]  |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,4.0,17.0,5.0,5.0]  |
+--------------+-----+-------------------------------------+
only showing top 10 rows


// Select (prediction, true label) and compute the accuracy on the test predictions;
// the test error is derived from it afterwards.
val evaluator = new MulticlassClassificationEvaluator()
  .setMetricName("accuracy")
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")

val accuracy = evaluator.evaluate(predictions)
accuracy: Double = 0.7032967032967034

// Test error is the complement of the accuracy computed above.
println(s"Test Error = ${1.0 - accuracy}")
Test Error = 0.29670329670329665


// Pull the fitted decision tree out of the pipeline model. The tree is located by its type
// rather than by a hard-coded position (it was stages(2), the slot of "dt" in the pipeline),
// so the lookup keeps working if stages are added or reordered, and a missing tree produces
// a descriptive error instead of a ClassCastException.
val treeModel: DecisionTreeClassificationModel = model.stages
  .collectFirst { case tree: DecisionTreeClassificationModel => tree }
  .getOrElse(
    throw new NoSuchElementException(
      "No DecisionTreeClassificationModel stage found in the fitted pipeline"))
treeModel: org.apache.spark.ml.classification.DecisionTreeClassificationModel = DecisionTreeClassificationModel (uid=dtc_7a8baf97abe7) of depth 5 with 33 nodes

treeModel.getLabelCol
res53: String = indexedLabel

treeModel.getFeaturesCol
res54: String = indexedFeatures

treeModel.featureImportances
res55: org.apache.spark.ml.linalg.Vector = (8,[0,2,3,4,5,6,7],[0.0640344247735859,0.1052957011097811,0.05343872372010684,0.17367191628391196,0.20372870264756315,0.2063093687074741,0.1935211627575769])

treeModel.getPredictionCol
res56: String = prediction

treeModel.getProbabilityCol
res57: String = probability

treeModel.numClasses
res58: Int = 2

treeModel.numFeatures
res59: Int = 8

treeModel.depth
res60: Int = 5

treeModel.numNodes
res61: Int = 33

treeModel.getImpurity
res62: String = entropy

treeModel.getMaxBins
res63: Int = 100

treeModel.getMaxDepth
res64: Int = 5

treeModel.getMaxMemoryInMB
res65: Int = 256

treeModel.getMinInfoGain
res66: Double = 0.01

treeModel.getMinInstancesPerNode
res67: Int = 10

 // Print the learned decision tree in its textual debug form
println(s"Learned classification tree model:\n${treeModel.toDebugString}")
Learned classification tree model:
DecisionTreeClassificationModel (uid=dtc_7a8baf97abe7) of depth 5 with 33 nodes
  If (feature 2 in {0.0,1.0,2.0,3.0})
   If (feature 5 in {3.0,6.0})
    Predict: 0.0
   Else (feature 5 not in {3.0,6.0})
    If (feature 4 in {3.0})
     Predict: 0.0
    Else (feature 4 not in {3.0})
     If (feature 3 in {0.0})
      If (feature 6 in {0.0,4.0,5.0})
       Predict: 0.0
      Else (feature 6 not in {0.0,4.0,5.0})
       Predict: 0.0
     Else (feature 3 not in {0.0})
      Predict: 0.0
  Else (feature 2 not in {0.0,1.0,2.0,3.0})
   If (feature 4 in {0.0,1.0,3.0,4.0})
    If (feature 7 in {0.0,1.0,2.0})
     If (feature 6 in {0.0,1.0,6.0})
      If (feature 4 in {1.0,4.0})
       Predict: 0.0
      Else (feature 4 not in {1.0,4.0})
       Predict: 0.0
     Else (feature 6 not in {0.0,1.0,6.0})
      If (feature 7 in {0.0,2.0})
       Predict: 0.0
      Else (feature 7 not in {0.0,2.0})
       Predict: 1.0
    Else (feature 7 not in {0.0,1.0,2.0})
     If (feature 5 in {0.0,1.0})
      Predict: 0.0
     Else (feature 5 not in {0.0,1.0})
      If (feature 6 in {0.0,1.0,2.0,5.0,6.0})
       Predict: 0.0
      Else (feature 6 not in {0.0,1.0,2.0,5.0,6.0})
       Predict: 0.0
   Else (feature 4 not in {0.0,1.0,3.0,4.0})
    If (feature 5 in {0.0,1.0,2.0,3.0,5.0,6.0})
     If (feature 0 in {0.0})
      If (feature 7 in {3.0})
       Predict: 0.0
      Else (feature 7 not in {3.0})
       Predict: 0.0
     Else (feature 0 not in {0.0})
      If (feature 7 in {0.0,2.0,4.0})
       Predict: 0.0
      Else (feature 7 not in {0.0,2.0,4.0})
       Predict: 1.0
    Else (feature 5 not in {0.0,1.0,2.0,3.0,5.0,6.0})
     Predict: 1.0

转载于:https://my.oschina.net/hblt147/blog/1525414

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
快速决策树分类器是一种高效的机器学习算法,用于处理大规模数据集的分类问题。该分类器通过构建一颗决策树来对数据进行分类。它的快速性来自于采用了一些优化技巧。 首先,快速决策树分类器使用了一种高效的特征选择方法,即基尼系数。基尼系数可以评估一个特征对数据集的划分能力,选择具有最佳划分能力的特征作为当前节点的划分依据,从而减少了计算量。 其次,快速决策树分类器采用了剪枝技术,即在构建决策树的过程中,对叶节点进行剪枝,去掉那些没有显著提升分类准确度的叶节点。这样可以避免模型的过拟合,减少了决策树的复杂度。 此外,快速决策树分类器还使用了并行计算技术,可以将数据集划分成多个子集,同时在不同的处理器上进行计算,从而提高了分类器的处理速度。 快速决策树分类器的应用非常广泛。它可以用于文本分类、图像分类、数据挖掘等领域。它的优势在于对大规模数据集的处理速度较快,且具有较好的分类准确度。但是,快速决策树分类器也有一些限制,例如对噪声数据敏感,对缺失值的处理能力较弱。 总之,快速决策树分类器是一种高效的分类算法,通过特征选择、剪枝和并行计算等技术优化了分类效率。它在大规模数据集上表现优异,可以广泛应用于各个领域。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值