Pyspark_Ml_Decision_Tree_RF_GBT

spark_ML_Decision_Tree_Random_Forest_Gradient_Boosted_Trees

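Classification models

MLlib supports the common machine-learning classification models:
logistic regression, softmax regression, decision trees, random forests, gradient-boosted trees, linear support vector machines, naive Bayes, One-vs-Rest, and the multilayer perceptron. These models all expose much the same interface, so only the commonly used decision tree, random forest, and gradient-boosted tree classifiers are shown below as examples.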

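1, Classification with a decision tree
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Reuse the active SparkSession (e.g. in the pyspark shell) or create one
spark = SparkSession.builder.getOrCreate()

# Load the data
dfdata = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")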
(dftrain, dftest) = dfdata.randomSplit([0.7, 0.3])

# Index the label column, mapping each label value to an integer index
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(dfdata)

# Index categorical features; a feature with more than 4 distinct values is treated as continuous
featureIndexer =\
    VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(dfdata)

# Build a decision tree model
dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures")

# Build the pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt])

# Fit the pipeline on the training set
model = pipeline.fit(dftrain)

dfpredictions = model.transform(dftest)

dfpredictions.select("prediction", "indexedLabel", "features").show(5)

# Evaluate the model's error on the test set
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(dfpredictions)
print("Test Error = %g " % (1.0 - accuracy))
treeModel = model.stages[2]
print(treeModel)

+----------+------------+--------------------+
|prediction|indexedLabel|            features|
+----------+------------+--------------------+
|       1.0|         1.0|(692,[98,99,100,1...|
|       1.0|         1.0|(692,[124,125,126...|
|       1.0|         1.0|(692,[124,125,126...|
|       1.0|         1.0|(692,[125,126,127...|
|       1.0|         1.0|(692,[126,127,128...|
+----------+------------+--------------------+
only showing top 5 rows

Test Error = 0.037037 
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_5711dbfcd91e, depth=2, numNodes=5, numClasses=2, numFeatures=692
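
The `maxCategories=4` setting used above decides, feature by feature, whether a column of the feature vector is indexed as categorical or left as continuous. The following is a minimal pure-Python sketch of that decision rule only, not Spark's implementation; the helper name `categorical_features` and the toy rows are made up for illustration.

```python
def categorical_features(rows, max_categories):
    """Return the indices of features with at most max_categories distinct values.

    Hypothetical helper that mimics the rule described for VectorIndexer;
    it is not part of PySpark.
    """
    num_features = len(rows[0])
    distinct = [set() for _ in range(num_features)]
    for row in rows:
        for i, value in enumerate(row):
            distinct[i].add(value)
    return [i for i, values in enumerate(distinct) if len(values) <= max_categories]


# Toy data: feature 0 takes 2 distinct values, feature 1 takes 5.
rows = [
    [0.0, 0.1],
    [1.0, 0.7],
    [0.0, 1.3],
    [1.0, 2.9],
    [0.0, 4.2],
]
print(categorical_features(rows, max_categories=4))  # [0]
```

In a fitted PySpark `VectorIndexerModel`, the features that were judged categorical, together with their value-to-index maps, can be inspected through its `categoryMaps` attribute.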

2, Random forest
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Load the data
dfdata = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")
(dftrain, dftest) = dfdata.randomSplit([0.7, 0.3])

# Index the label column, mapping each label value to an integer index
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(dfdata)

# Index categorical features
featureIndexer =\
    VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(dfdata)


# Use a random forest model
rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", numTrees=10)

# Convert the predicted indices back to the original label strings
labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel",
                               labels=labelIndexer.labels)

# Build the pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf, labelConverter])

# Fit the pipeline on the training set
model = pipeline.fit(dftrain)

# Make predictions
dfpredictions = model.transform(dftest)

dfpredictions.select("predictedLabel", "label", "features").show(5)

# Evaluate the model
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(dfpredictions)
print("Test Error = %g" % (1.0 - accuracy))

rfModel = model.stages[2]
print(rfModel)  

+--------------+-----+--------------------+
|predictedLabel|label|            features|
+--------------+-----+--------------------+
|           0.0|  0.0|(692,[122,123,124...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[124,125,126...|
+--------------+-----+--------------------+
only showing top 5 rows

Test Error = 0
RandomForestClassificationModel: uid=RandomForestClassifier_9d8f7dfec86b, numTrees=10, numClasses=2, numFeatures=692
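
The random forest pipeline indexes the labels with `StringIndexer` and maps predictions back with `IndexToString`. By default `StringIndexer` orders labels by descending frequency, so the most frequent label gets index 0, which is why an `indexedLabel` value need not equal the raw `label`. A rough pure-Python sketch of that round trip follows; the function names are hypothetical, and breaking ties alphabetically is an assumption of this sketch.

```python
from collections import Counter


def fit_labels(values):
    """Order distinct labels by descending frequency (ties broken alphabetically here)."""
    counts = Counter(str(v) for v in values)
    return sorted(counts, key=lambda label: (-counts[label], label))


def to_index(labels, value):
    """Label -> index, the direction StringIndexer handles."""
    return float(labels.index(str(value)))


def to_label(labels, index):
    """Index -> label, the direction IndexToString handles."""
    return labels[int(index)]


raw = [1.0, 0.0, 1.0, 1.0, 0.0]  # toy label column
labels = fit_labels(raw)  # plays the role of labelIndexer.labels
indexed = [to_index(labels, v) for v in raw]
restored = [to_label(labels, i) for i in indexed]

print(labels)    # ['1.0', '0.0']
print(indexed)   # [0.0, 1.0, 0.0, 0.0, 1.0]
print(restored)  # ['1.0', '0.0', '1.0', '1.0', '0.0']
```

Here the raw label `1.0` is the more frequent one, so it is assigned index `0.0`, and the restored labels come back as strings.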

3, Gradient-boosted trees
from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Load the data
dfdata = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")
(dftrain, dftest) = dfdata.randomSplit([0.7, 0.3])

# Index the label column, mapping each label value to an integer index
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(dfdata)

# Index categorical features
featureIndexer =\
    VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(dfdata)

# Use a gradient-boosted trees model
gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=20)

# Build the pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt])

# Fit the pipeline on the training set
model = pipeline.fit(dftrain)

# Make predictions
dfpredictions = model.transform(dftest)
dfpredictions.select("prediction", "indexedLabel", "features").show(5)

# Evaluate the model
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(dfpredictions)
print("Test Error = %g" % (1.0 - accuracy))

gbtModel = model.stages[2]
print(gbtModel)  

+----------+------------+--------------------+
|prediction|indexedLabel|            features|
+----------+------------+--------------------+
|       1.0|         1.0|(692,[95,96,97,12...|
|       1.0|         1.0|(692,[98,99,100,1...|
|       1.0|         1.0|(692,[122,123,148...|
|       1.0|         1.0|(692,[124,125,126...|
|       1.0|         1.0|(692,[124,125,126...|
+----------+------------+--------------------+
only showing top 5 rows

Test Error = 0.0689655
GBTClassificationModel: uid = GBTClassifier_e3d7713552b3, numTrees=20, numClasses=2, numFeatures=692
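
All three examples report the test error as `1.0 - accuracy`, where accuracy is the fraction of test rows whose `prediction` equals `indexedLabel`. A small pure-Python sketch of that calculation is given below; the `accuracy` helper and the toy counts (29 rows, 2 misclassified) are assumptions for illustration and are not taken from the runs above.

```python
def accuracy(predictions, labels):
    """Fraction of rows where the predicted index equals the true index."""
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return correct / len(labels)


# Toy test set: 29 rows, then flip 2 predictions to simulate errors.
labels = [1.0] * 15 + [0.0] * 14
predictions = list(labels)
predictions[0] = 0.0
predictions[-1] = 1.0

acc = accuracy(predictions, labels)
print("Test Error = %g" % (1.0 - acc))
```

Because `randomSplit` is called without a seed, the test set and therefore the reported error will vary from run to run.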