pyspark 机器学习 逻辑回归 Pipeline

鸢尾花数据集

构建环境导入模块

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

spark = SparkSession.builder.config(conf = SparkConf()).getOrCreate()

from pyspark.ml.linalg import Vector,Vectors
from pyspark.sql.types import DoubleType, StructType, StructField
from pyspark.sql import Row,functions
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer, VectorAssembler
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel,\
BinaryLogisticRegressionSummary,LogisticRegression

数据集示例

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa

读取鸢尾花数据集

schema = StructType([
    StructField("_c0", DoubleType(), True),
    StructField("_c1", DoubleType(), True),
    StructField("_c2", DoubleType(), True),
    StructField("_c3", DoubleType(), True),
    StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
data.show(5)
+---+---+---+---+-----------+
|_c0|_c1|_c2|_c3|        _c4|
+---+---+---+---+-----------+
|5.1|3.5|1.4|0.2|Iris-setosa|
|4.9|3.0|1.4|0.2|Iris-setosa|
|4.7|3.2|1.3|0.2|Iris-setosa|
|4.6|3.1|1.5|0.2|Iris-setosa|
|5.0|3.6|1.4|0.2|Iris-setosa|
+---+---+---+---+-----------+
only showing top 5 rows

将特征列转成向量

df_assembler = VectorAssembler(inputCols=['_c0','_c1','_c2',\
                                         '_c3'], outputCol='features')
data = df_assembler.transform(data).select('features','_c4')
data.show(5,truncate=False)
+-----------------+-----------+
|features         |_c4        |
+-----------------+-----------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|
|[4.9,3.0,1.4,0.2]|Iris-setosa|
|[4.7,3.2,1.3,0.2]|Iris-setosa|
|[4.6,3.1,1.5,0.2]|Iris-setosa|
|[5.0,3.6,1.4,0.2]|Iris-setosa|
+-----------------+-----------+
only showing top 5 rows

将标签列转成数值

StringIndexer 将一列按照值出现频率的大小转换成数值,例如该列共有5类值,出现最多的将对应转成0.0,其他按出现频率大小依次转为1.0,2.0,3.0,4.0

labelIndexer = StringIndexer().setInputCol("_c4"). \
    setOutputCol("indexedLabel").fit(data)
data = labelIndexer.transform(data)
data.show(5)
+-----------------+-----------+------------+
|         features|        _c4|indexedLabel|
+-----------------+-----------+------------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|         0.0|
|[4.9,3.0,1.4,0.2]|Iris-setosa|         0.0|
|[4.7,3.2,1.3,0.2]|Iris-setosa|         0.0|
|[4.6,3.1,1.5,0.2]|Iris-setosa|         0.0|
|[5.0,3.6,1.4,0.2]|Iris-setosa|         0.0|
+-----------------+-----------+------------+
only showing top 5 rows

为向量中特征转成索引

VectorIndexer 与 StringIndexer 类似,会将向量中的值转成索引数值,只不过多了一个参数 maxCategories 只有向量中某一对应列的数值种类不超过这个值就会被转化有编号的离散值(index),如果数值种类超过 maxCategories 则不会转化,按连续变量处理

featureIndexer = VectorIndexer(maxCategories=5).setInputCol("features"). \
    setOutputCol("indexedFeatures").fit(data)
data = featureIndexer.transform(data)
data.show(5)
+-----------------+-----------+------------+-----------------+
|         features|        _c4|indexedLabel|  indexedFeatures|
+-----------------+-----------+------------+-----------------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|         0.0|[5.1,3.5,1.4,0.2]|
|[4.9,3.0,1.4,0.2]|Iris-setosa|         0.0|[4.9,3.0,1.4,0.2]|
|[4.7,3.2,1.3,0.2]|Iris-setosa|         0.0|[4.7,3.2,1.3,0.2]|
|[4.6,3.1,1.5,0.2]|Iris-setosa|         0.0|[4.6,3.1,1.5,0.2]|
|[5.0,3.6,1.4,0.2]|Iris-setosa|         0.0|[5.0,3.6,1.4,0.2]|
+-----------------+-----------+------------+-----------------+
only showing top 5 rows

分割数据集

trainData, testData = data.randomSplit([0.7, 0.3])

调用logistic模型进行训练预测

testData 中的 prediction便是预测的结果,probability 为概率值

lr = LogisticRegression(labelCol='indexedLabel',featuresCol='indexedFeatures',\
                        maxIter=100, regParam=0.3, elasticNetParam=0.8).fit(trainData)
testData = lr.transform(testData)
testData .show(5)
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
|features         |_c4            |indexedLabel|indexedFeatures  |rawPrediction                                                 |probability                                                |prediction|
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
|[4.6,3.1,1.5,0.2]|Iris-setosa    |0.0         |[4.6,3.1,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |
|[4.9,2.4,3.3,1.0]|Iris-versicolor|1.0         |[4.9,2.4,3.3,1.0]|[-0.24071644549981097,-0.34589675872056497,-0.454668906502002]|[0.3693377344468649,0.33246386640343106,0.2981983991497039]|0.0       |
|[4.9,3.1,1.5,0.1]|Iris-setosa    |0.0         |[4.9,3.1,1.5,0.1]|[0.5460088067214048,-0.34589675872056497,-0.646558280238517]  |[0.5836637411293467,0.2392285726499658,0.1771076862206876] |0.0       |
|[5.0,3.0,1.6,0.2]|Iris-setosa    |0.0         |[5.0,3.0,1.6,0.2]|[0.48785284033489695,-0.34589675872056497,-0.6252372387122376]|[0.5672280025589029,0.24641367880730375,0.1863583186337933]|0.0       |
|[5.0,3.4,1.5,0.2]|Iris-setosa    |0.0         |[5.0,3.4,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
only showing top 5 rows

将预测值转回成标签

IndexToString 与 StringIndexer 相反,会将数值转成标签

labelConverter = IndexToString(inputCol='prediction',outputCol='predictedLabel',\
                              labels=labelIndexer.labels)
testData = labelConverter.transform(testData)
testData.show(5)
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
|features         |_c4            |indexedLabel|indexedFeatures  |rawPrediction                                                 |probability                                                |prediction|predictedLabel|
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
|[4.6,3.1,1.5,0.2]|Iris-setosa    |0.0         |[4.6,3.1,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |Iris-setosa   |
|[4.9,2.4,3.3,1.0]|Iris-versicolor|1.0         |[4.9,2.4,3.3,1.0]|[-0.24071644549981097,-0.34589675872056497,-0.454668906502002]|[0.3693377344468649,0.33246386640343106,0.2981983991497039]|0.0       |Iris-setosa   |
|[4.9,3.1,1.5,0.1]|Iris-setosa    |0.0         |[4.9,3.1,1.5,0.1]|[0.5460088067214048,-0.34589675872056497,-0.646558280238517]  |[0.5836637411293467,0.2392285726499658,0.1771076862206876] |0.0       |Iris-setosa   |
|[5.0,3.0,1.6,0.2]|Iris-setosa    |0.0         |[5.0,3.0,1.6,0.2]|[0.48785284033489695,-0.34589675872056497,-0.6252372387122376]|[0.5672280025589029,0.24641367880730375,0.1863583186337933]|0.0       |Iris-setosa   |
|[5.0,3.4,1.5,0.2]|Iris-setosa    |0.0         |[5.0,3.4,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |Iris-setosa   |
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
only showing top 5 rows


上面用模型训练需要很多部步骤,在spark 中可以用 Pipeline 将这些步骤集中起来形成一个管道

Pipeline 的应用

### 读取鸢尾花数据集
schema = StructType([
    StructField("_c0", DoubleType(), True),
    StructField("_c1", DoubleType(), True),
    StructField("_c2", DoubleType(), True),
    StructField("_c3", DoubleType(), True),
    StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
# data.show(5)

labelIndexer = StringIndexer().setInputCol("_c4"). \
    setOutputCol("indexedLabel").fit(data)
data = labelIndexer.transform(data)
# data.show()

trainData, testData = data.randomSplit([0.7, 0.3])

assembler = VectorAssembler(inputCols=['_c0','_c1','_c2','_c3'], outputCol='features')

featureIndexer = VectorIndexer().setInputCol('features'). \
    setOutputCol("indexedFeatures")


lr = LogisticRegression().\
    setLabelCol("indexedLabel"). \
    setFeaturesCol("indexedFeatures"). \
    setMaxIter(100). \
    setRegParam(0.3). \
    setElasticNetParam(0.8)
# print("LogisticRegression parameters:\n" + lr.explainParams())

labelConverter = IndexToString(). \
    setInputCol("prediction"). \
    setOutputCol("predictedLabel"). \
    setLabels(labelIndexer.labels)
lrPipeline = Pipeline(). \
    setStages([assembler, featureIndexer, lr, labelConverter])

lrPipelineModel = lrPipeline.fit(trainData)

## 保存模型
lrPipelineModel.save('./data/lr_model') 

加载模型并在测试集上预测

### 加载模型
l_r = PipelineModel.load('./data/lr_model')

lrPredictions = l_r.transform(testData)
preRel = lrPredictions.select("predictedLabel","_c4","features","probability")
preRel.show()

evaluator = MulticlassClassificationEvaluator(). \
    setLabelCol("indexedLabel"). \
    setPredictionCol("prediction")
lrAccuracy = evaluator.evaluate(lrPredictions)
lrAccuracy
+---------------+---------------+-----------------+--------------------+
| predictedLabel|            _c4|         features|         probability|
+---------------+---------------+-----------------+--------------------+
|    Iris-setosa|    Iris-setosa|[4.4,3.0,1.3,0.2]|[0.55438709514625...|
|    Iris-setosa|    Iris-setosa|[4.5,2.3,1.3,0.3]|[0.54484378407858...|
|    Iris-setosa|    Iris-setosa|[4.6,3.2,1.4,0.2]|[0.54764212881478...|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|[0.57449259710402...|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|[0.55711960386173...|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|[0.51370352584700...|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|[0.55038245390810...|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|[0.55038245390810...|
|Iris-versicolor|Iris-versicolor|[5.0,2.0,3.5,1.0]|[0.33266831839072...|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|[0.54764212881478...|
|    Iris-setosa|    Iris-setosa|[5.1,3.5,1.4,0.3]|[0.53807472588121...|
|    Iris-setosa|    Iris-setosa|[5.1,3.8,1.9,0.4]|[0.49437428087769...|
|    Iris-setosa|    Iris-setosa|[5.2,3.4,1.4,0.2]|[0.54764212881478...|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|[0.54087961843218...|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.7,0.2]|[0.52731180331887...|
|Iris-versicolor|Iris-versicolor|[5.5,2.4,3.8,1.1]|[0.30612463894390...|
|Iris-versicolor|Iris-versicolor|[5.6,2.5,3.9,1.1]|[0.30036420997267...|
|Iris-versicolor|Iris-versicolor|[5.6,2.7,4.2,1.3]|[0.26721917673024...|
|Iris-versicolor|Iris-versicolor|[5.6,3.0,4.1,1.3]|[0.27259161091433...|
| Iris-virginica|Iris-versicolor|[5.6,3.0,4.5,1.5]|[0.23632948715024...|
+---------------+---------------+-----------------+--------------------+
only showing top 20 rows

0.8417582417582418

下面一步一步去转换数据,最后用加载管道模型一步出结果

schema = StructType([
    StructField("_c0", DoubleType(), True),
    StructField("_c1", DoubleType(), True),
    StructField("_c2", DoubleType(), True),
    StructField("_c3", DoubleType(), True),
    StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
data = data.drop('_c4')
data.show(5)
data1 = lrPipelineModel.stages[0].transform(data)
data1.show(5)
data2 = lrPipelineModel.stages[1].transform(data1)
data2.show(5)
data3 = lrPipelineModel.stages[2].transform(data2)
data3.show(5)
data4 = lrPipelineModel.stages[3].transform(data3)
data4.show(5)

l_r = PipelineModel.load('./data/lr_model')
lrPredictions = l_r.transform(data)
lrPredictions.show(5)
+---+---+---+---+
|_c0|_c1|_c2|_c3|
+---+---+---+---+
|5.1|3.5|1.4|0.2|
|4.9|3.0|1.4|0.2|
|4.7|3.2|1.3|0.2|
|4.6|3.1|1.5|0.2|
|5.0|3.6|1.4|0.2|
+---+---+---+---+
only showing top 5 rows

+---+---+---+---+-----------------+
|_c0|_c1|_c2|_c3|         features|
+---+---+---+---+-----------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|
+---+---+---+---+-----------------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|
+---+---+---+---+-----------------+-----------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|
+---+---+---+---+-----------------+-----------------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|       rawPrediction|         probability|prediction|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...|       0.0|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...|       0.0|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|       rawPrediction|         probability|prediction|predictedLabel|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...|       0.0|   Iris-setosa|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...|       0.0|   Iris-setosa|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|       rawPrediction|         probability|prediction|predictedLabel|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...|       0.0|   Iris-setosa|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...|       0.0|   Iris-setosa|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 5 rows
  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

viziviuz

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值