Set up the environment and import modules
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
spark = SparkSession.builder.config(conf = SparkConf()).getOrCreate()
from pyspark.ml.linalg import Vector,Vectors
# StringType is needed for the label column in the schema below
from pyspark.sql.types import DoubleType, StringType, StructType, StructField
from pyspark.sql import Row,functions
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# PipelineModel is needed later to load the saved pipeline
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer, VectorAssembler
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel,\
BinaryLogisticRegressionSummary
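Dataset sample (each line holds sepal length, sepal width, petal length and petal width in cm, followed by the species)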
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
Read the Iris dataset with an explicit schema (the file has no header, so the columns are named _c0 to _c4)
schema = StructType([
StructField("_c0", DoubleType(), True),
StructField("_c1", DoubleType(), True),
StructField("_c2", DoubleType(), True),
StructField("_c3", DoubleType(), True),
StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
data.show(5)
+---+---+---+---+-----------+
|_c0|_c1|_c2|_c3| _c4|
+---+---+---+---+-----------+
|5.1|3.5|1.4|0.2|Iris-setosa|
|4.9|3.0|1.4|0.2|Iris-setosa|
|4.7|3.2|1.3|0.2|Iris-setosa|
|4.6|3.1|1.5|0.2|Iris-setosa|
|5.0|3.6|1.4|0.2|Iris-setosa|
+---+---+---+---+-----------+
only showing top 5 rows
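To show what the explicit schema buys you, here is a plain-Python sketch (no Spark involved) that parses the sample lines above into the same column types: four doubles and one string. The helper `parse_iris_lines` is hypothetical. Turning unparsable fields into `None` is only a loose imitation of Spark's default PERMISSIVE CSV mode, which sets malformed fields to null.

```python
import csv
import io

# Column names and Python types mirroring the Spark schema above:
# four DoubleType columns followed by one StringType column.
SCHEMA = [("_c0", float), ("_c1", float), ("_c2", float), ("_c3", float), ("_c4", str)]

SAMPLE = """5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
"""

def parse_iris_lines(text):
    """Parse CSV text into dicts, casting each field to its schema type.

    A field that cannot be cast becomes None (a loose imitation of
    Spark's PERMISSIVE mode, not its exact behavior).
    """
    rows = []
    for fields in csv.reader(io.StringIO(text)):
        if not fields:  # skip blank lines
            continue
        row = {}
        for (name, cast), raw in zip(SCHEMA, fields):
            try:
                row[name] = cast(raw)
            except ValueError:
                row[name] = None
        rows.append(row)
    return rows

rows = parse_iris_lines(SAMPLE)
print(rows[0])  # {'_c0': 5.1, '_c1': 3.5, '_c2': 1.4, '_c3': 0.2, '_c4': 'Iris-setosa'}
```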
Assemble the feature columns into a single vector column
df_assembler = VectorAssembler(inputCols=['_c0','_c1','_c2',\
'_c3'], outputCol='features')
data = df_assembler.transform(data).select('features','_c4')
data.show(5,truncate=False)
+-----------------+-----------+
|features |_c4 |
+-----------------+-----------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|
|[4.9,3.0,1.4,0.2]|Iris-setosa|
|[4.7,3.2,1.3,0.2]|Iris-setosa|
|[4.6,3.1,1.5,0.2]|Iris-setosa|
|[5.0,3.6,1.4,0.2]|Iris-setosa|
+-----------------+-----------+
only showing top 5 rows
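At its core, VectorAssembler concatenates the listed input columns, in order, into one vector per row. A minimal plain-Python sketch of that idea follows. The helper `assemble` is hypothetical, and it ignores details such as vector-typed inputs and null handling.

```python
def assemble(rows, input_cols, output_col="features"):
    """Return new rows with input_cols concatenated, in order, into output_col."""
    out = []
    for row in rows:
        new_row = dict(row)  # keep the original columns, as transform() does
        new_row[output_col] = [float(row[c]) for c in input_cols]
        out.append(new_row)
    return out

rows = [
    {"_c0": 5.1, "_c1": 3.5, "_c2": 1.4, "_c3": 0.2, "_c4": "Iris-setosa"},
    {"_c0": 4.9, "_c1": 3.0, "_c2": 1.4, "_c3": 0.2, "_c4": "Iris-setosa"},
]
assembled = assemble(rows, ["_c0", "_c1", "_c2", "_c3"])
# Mirror .select('features', '_c4') from the Spark code
selected = [{"features": r["features"], "_c4": r["_c4"]} for r in assembled]
print(selected[0])  # {'features': [5.1, 3.5, 1.4, 0.2], '_c4': 'Iris-setosa'}
```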
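Convert the label column to numeric indices
StringIndexer maps the values of a column to numeric indices ordered by frequency. For example, if the column contains 5 distinct values, the most frequent one becomes 0.0 and the others become 1.0, 2.0, 3.0 and 4.0 in decreasing order of frequency.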
labelIndexer = StringIndexer().setInputCol("_c4"). \
setOutputCol("indexedLabel").fit(data)
data = labelIndexer.transform(data)
data.show(5)
+-----------------+-----------+------------+
| features| _c4|indexedLabel|
+-----------------+-----------+------------+
|[5.1,3.5,1.4,0.2]|Iris-setosa| 0.0|
|[4.9,3.0,1.4,0.2]|Iris-setosa| 0.0|
|[4.7,3.2,1.3,0.2]|Iris-setosa| 0.0|
|[4.6,3.1,1.5,0.2]|Iris-setosa| 0.0|
|[5.0,3.6,1.4,0.2]|Iris-setosa| 0.0|
+-----------------+-----------+------------+
only showing top 5 rows
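The frequency ordering can be sketched in plain Python. The sketch assumes Spark's default `stringOrderType="frequencyDesc"`, under which the Spark documentation says strings of equal frequency are further sorted alphabetically. Each Iris class occurs 50 times, so this is consistent with Iris-setosa receiving 0.0 above. The helper names are hypothetical.

```python
from collections import Counter

def fit_string_indexer(values):
    """Return the labels ordered by descending frequency, ties broken alphabetically."""
    counts = Counter(values)
    # Sort key: higher count first, then the string itself for equal counts
    return sorted(counts, key=lambda v: (-counts[v], v))

def transform_string_indexer(values, labels):
    """Map each string to the float index of its position in labels.

    Unseen strings raise KeyError here, similar in spirit to Spark's
    default handleInvalid="error".
    """
    index = {label: float(i) for i, label in enumerate(labels)}
    return [index[v] for v in values]

column = ["Iris-versicolor", "Iris-setosa", "Iris-virginica"] * 50
labels = fit_string_indexer(column)
print(labels)  # ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
print(transform_string_indexer(["Iris-setosa", "Iris-virginica"], labels))  # [0.0, 2.0]
```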
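Index the categorical features in the vector
VectorIndexer is similar to StringIndexer in that it converts values in a vector column into indices, but it takes an extra parameter, maxCategories. For each position in the vector, if that feature has at most maxCategories distinct values, it is treated as categorical and its values are replaced by discrete indices. If it has more than maxCategories distinct values, it is left unchanged and treated as continuous.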
featureIndexer = VectorIndexer(maxCategories=5).setInputCol("features"). \
setOutputCol("indexedFeatures").fit(data)
data = featureIndexer.transform(data)
data.show(5)
+-----------------+-----------+------------+-----------------+
| features| _c4|indexedLabel| indexedFeatures|
+-----------------+-----------+------------+-----------------+
|[5.1,3.5,1.4,0.2]|Iris-setosa| 0.0|[5.1,3.5,1.4,0.2]|
|[4.9,3.0,1.4,0.2]|Iris-setosa| 0.0|[4.9,3.0,1.4,0.2]|
|[4.7,3.2,1.3,0.2]|Iris-setosa| 0.0|[4.7,3.2,1.3,0.2]|
|[4.6,3.1,1.5,0.2]|Iris-setosa| 0.0|[4.6,3.1,1.5,0.2]|
|[5.0,3.6,1.4,0.2]|Iris-setosa| 0.0|[5.0,3.6,1.4,0.2]|
+-----------------+-----------+------------+-----------------+
only showing top 5 rows
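The maxCategories rule can be sketched in plain Python: count the distinct values at each vector position, and treat a position as categorical only if it has at most maxCategories distinct values. This is a simplified illustration, not Spark's implementation. In particular, assigning indices to a categorical feature's values in ascending order is an assumption made here for readability, and the helper names are hypothetical.

```python
def fit_vector_indexer(vectors, max_categories):
    """Return {feature_position: {value: index}} for positions deemed categorical."""
    category_maps = {}
    for j in range(len(vectors[0])):
        distinct = sorted({v[j] for v in vectors})
        if len(distinct) <= max_categories:  # few distinct values -> categorical
            category_maps[j] = {value: float(i) for i, value in enumerate(distinct)}
        # otherwise position j is left out, i.e. treated as continuous
    return category_maps

def transform_vector_indexer(vectors, category_maps):
    """Replace categorical values with their indices; keep continuous values as-is."""
    return [
        [category_maps[j][x] if j in category_maps else x for j, x in enumerate(v)]
        for v in vectors
    ]

# Position 0 has 3 distinct values, position 1 has 6 distinct values
vectors = [[1.0, 0.1], [3.0, 0.2], [5.0, 0.3], [1.0, 0.4], [3.0, 0.5], [5.0, 0.6]]
maps = fit_vector_indexer(vectors, max_categories=5)
print(maps)  # {0: {1.0: 0.0, 3.0: 1.0, 5.0: 2.0}}
print(transform_vector_indexer(vectors, maps)[:2])  # [[0.0, 0.1], [1.0, 0.2]]
```

On the full Iris data, each of the four measurements has more than 5 distinct values, so no position becomes categorical. This matches the output above, where indexedFeatures equals features.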
Split the dataset (about 70% for training and 30% for testing)
trainData, testData = data.randomSplit([0.7, 0.3])
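`randomSplit` does not cut the data at exact sizes. The weights are normalized and each row is assigned to a split at random, so the 70/30 sizes are only approximate and change between runs unless a seed is passed (for example `data.randomSplit([0.7, 0.3], seed=42)`). The plain-Python sketch below illustrates the idea with per-row sampling. Spark's actual implementation samples per partition and differs in detail, and the helper `random_split` is hypothetical.

```python
import bisect
import itertools
import random

def random_split(rows, weights, seed=None):
    """Assign each row to one split with probability proportional to its weight."""
    total = sum(weights)
    # Cumulative boundaries of the normalized weights, e.g. [0.7, 1.0] for [0.7, 0.3]
    bounds = list(itertools.accumulate(w / total for w in weights))
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        # Index of the first boundary above the draw; min() guards against
        # floating-point rounding leaving the last boundary just below 1.0
        i = min(bisect.bisect_right(bounds, rng.random()), len(weights) - 1)
        splits[i].append(row)
    return splits

train, test = random_split(list(range(150)), [0.7, 0.3], seed=42)
print(len(train), len(test))  # close to a 70/30 split, but usually not exactly 105/45
```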
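Train the logistic regression model and predict
In the transformed testData, the prediction column holds the predicted class index, probability holds the predicted probability of each class, and rawPrediction holds the raw scores the probabilities are computed from.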
lr = LogisticRegression(labelCol='indexedLabel',featuresCol='indexedFeatures',\
maxIter=100, regParam=0.3, elasticNetParam=0.8).fit(trainData)
testData = lr.transform(testData)
testData.show(5, truncate=False)
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
|features |_c4 |indexedLabel|indexedFeatures |rawPrediction |probability |prediction|
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
|[4.6,3.1,1.5,0.2]|Iris-setosa |0.0 |[4.6,3.1,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0 |
|[4.9,2.4,3.3,1.0]|Iris-versicolor|1.0 |[4.9,2.4,3.3,1.0]|[-0.24071644549981097,-0.34589675872056497,-0.454668906502002]|[0.3693377344468649,0.33246386640343106,0.2981983991497039]|0.0 |
|[4.9,3.1,1.5,0.1]|Iris-setosa |0.0 |[4.9,3.1,1.5,0.1]|[0.5460088067214048,-0.34589675872056497,-0.646558280238517] |[0.5836637411293467,0.2392285726499658,0.1771076862206876] |0.0 |
|[5.0,3.0,1.6,0.2]|Iris-setosa |0.0 |[5.0,3.0,1.6,0.2]|[0.48785284033489695,-0.34589675872056497,-0.6252372387122376]|[0.5672280025589029,0.24641367880730375,0.1863583186337933]|0.0 |
|[5.0,3.4,1.5,0.2]|Iris-setosa |0.0 |[5.0,3.4,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0 |
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
only showing top 5 rows
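With three classes, Spark's LogisticRegression uses the multinomial family by default. Its probability column is then the softmax of rawPrediction, and prediction is the index of the largest probability. The plain-Python sketch below checks this against the first row of the output above. The helper names are hypothetical.

```python
import math

def softmax(raw):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(raw)  # subtracting the max avoids overflow and does not change the result
    exps = [math.exp(r - m) for r in raw]
    total = sum(exps)
    return [e / total for e in exps]

def predict(raw):
    """Return the index of the most probable class as a float, like the prediction column."""
    probs = softmax(raw)
    return float(probs.index(max(probs)))

# rawPrediction and probability copied from the first row of testData.show(5) above
raw = [0.5171107908618575, -0.34589675872056497, -0.6252372387122376]
spark_prob = [0.5743956461075854, 0.24233253347984832, 0.1832718204125663]

probs = softmax(raw)
print(all(abs(p - q) < 1e-9 for p, q in zip(probs, spark_prob)))  # True
print(predict(raw))  # 0.0
```

The same calculation on the second row (an Iris-versicolor sample) also yields 0.0, which is why that row is misclassified as Iris-setosa.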
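Convert the predictions back to labels
IndexToString is the inverse of StringIndexer: it converts numeric indices back into the original string labels. Here it maps the prediction column back to species names using labelIndexer.labels.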
labelConverter = IndexToString(inputCol='prediction',outputCol='predictedLabel',\
labels=labelIndexer.labels)
testData = labelConverter.transform(testData)
testData.show(5, truncate=False)
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
|features |_c4 |indexedLabel|indexedFeatures |rawPrediction |probability |prediction|predictedLabel|
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
|[4.6,3.1,1.5,0.2]|Iris-setosa |0.0 |[4.6,3.1,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0 |Iris-setosa |
|[4.9,2.4,3.3,1.0]|Iris-versicolor|1.0 |[4.9,2.4,3.3,1.0]|[-0.24071644549981097,-0.34589675872056497,-0.454668906502002]|[0.3693377344468649,0.33246386640343106,0.2981983991497039]|0.0 |Iris-setosa |
|[4.9,3.1,1.5,0.1]|Iris-setosa |0.0 |[4.9,3.1,1.5,0.1]|[0.5460088067214048,-0.34589675872056497,-0.646558280238517] |[0.5836637411293467,0.2392285726499658,0.1771076862206876] |0.0 |Iris-setosa |
|[5.0,3.0,1.6,0.2]|Iris-setosa |0.0 |[5.0,3.0,1.6,0.2]|[0.48785284033489695,-0.34589675872056497,-0.6252372387122376]|[0.5672280025589029,0.24641367880730375,0.1863583186337933]|0.0 |Iris-setosa |
|[5.0,3.4,1.5,0.2]|Iris-setosa |0.0 |[5.0,3.4,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0 |Iris-setosa |
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
only showing top 5 rows
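IndexToString looks each index up in the labels array, undoing the StringIndexer mapping. Below is a minimal plain-Python sketch that assumes the label order StringIndexer produced above. The helper name is hypothetical, and raising `ValueError` on a bad index is this sketch's choice, not Spark's exact error.

```python
def index_to_string(indices, labels):
    """Map each float index back to its label; indices must be whole numbers in range."""
    out = []
    for x in indices:
        i = int(x)
        if i != x or not 0 <= i < len(labels):
            raise ValueError(f"invalid label index: {x}")
        out.append(labels[i])
    return out

# Assumed label order from labelIndexer.labels: 0.0 -> Iris-setosa, 1.0 -> Iris-versicolor, ...
labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
print(index_to_string([0.0, 1.0, 2.0, 0.0], labels))
# ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica', 'Iris-setosa']
```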
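Training the model above took many separate steps. In Spark, a Pipeline chains these steps into a single workflow that can be fitted, saved and loaded as one unit.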
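To make the idea concrete, here is a toy plain-Python sketch of how a pipeline chains stages. It is not Spark's code. Following the documented Pipeline semantics, fit() walks the stages in order: an estimator is fitted on the current data and replaced by the model it produces, a transformer is kept as-is, and the data is transformed before the next stage sees it. The resulting model applies the fitted stages in the same order. All class names here (`AddOne`, `MaxScaler`, `ToyPipeline`, ...) are hypothetical.

```python
class Transformer:
    def transform(self, data):
        raise NotImplementedError

class Estimator:
    def fit(self, data):
        raise NotImplementedError

class AddOne(Transformer):
    """Hypothetical transformer: adds 1 to every value."""
    def transform(self, data):
        return [x + 1 for x in data]

class MaxScalerModel(Transformer):
    def __init__(self, max_value):
        self.max_value = max_value

    def transform(self, data):
        return [x / self.max_value for x in data]

class MaxScaler(Estimator):
    """Hypothetical estimator: learns the maximum, then divides by it."""
    def fit(self, data):
        return MaxScalerModel(max(data))

class ToyPipelineModel(Transformer):
    def __init__(self, stages):
        self.stages = stages

    def transform(self, data):
        for stage in self.stages:  # apply the fitted stages in order
            data = stage.transform(data)
        return data

class ToyPipeline(Estimator):
    def __init__(self, stages):
        self.stages = stages

    def fit(self, data):
        fitted = []
        for stage in self.stages:
            # Estimators are fitted on the output of the earlier stages
            model = stage.fit(data) if isinstance(stage, Estimator) else stage
            # Spark skips this for stages after the last estimator; done always here
            data = model.transform(data)
            fitted.append(model)
        return ToyPipelineModel(fitted)

model = ToyPipeline([AddOne(), MaxScaler()]).fit([1, 3, 7])
print(model.stages[1].max_value)  # 8, learned after AddOne ran
print(model.transform([3, 7]))    # [0.5, 1.0]
```

As in Spark, the fitted stages are exposed as `stages`, much like `lrPipelineModel.stages[i]` in the real pipeline.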
Using a Pipeline
### Read the Iris dataset
schema = StructType([
StructField("_c0", DoubleType(), True),
StructField("_c1", DoubleType(), True),
StructField("_c2", DoubleType(), True),
StructField("_c3", DoubleType(), True),
StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
# data.show(5)
labelIndexer = StringIndexer().setInputCol("_c4"). \
setOutputCol("indexedLabel").fit(data)
data = labelIndexer.transform(data)
# data.show()
trainData, testData = data.randomSplit([0.7, 0.3])
assembler = VectorAssembler(inputCols=['_c0','_c1','_c2','_c3'], outputCol='features')
featureIndexer = VectorIndexer().setInputCol('features'). \
setOutputCol("indexedFeatures")
lr = LogisticRegression().\
setLabelCol("indexedLabel"). \
setFeaturesCol("indexedFeatures"). \
setMaxIter(100). \
setRegParam(0.3). \
setElasticNetParam(0.8)
# print("LogisticRegression parameters:\n" + lr.explainParams())
labelConverter = IndexToString(). \
setInputCol("prediction"). \
setOutputCol("predictedLabel"). \
setLabels(labelIndexer.labels)
lrPipeline = Pipeline(). \
setStages([assembler, featureIndexer, lr, labelConverter])
lrPipelineModel = lrPipeline.fit(trainData)
## Save the model (overwrite() avoids an error if the directory already exists)
lrPipelineModel.write().overwrite().save('./data/lr_model')
Load the model and predict on the test set
### Load the model
l_r = PipelineModel.load('./data/lr_model')
lrPredictions = l_r.transform(testData)
preRel = lrPredictions.select("predictedLabel","_c4","features","probability")
preRel.show()
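# MulticlassClassificationEvaluator defaults to metricName="f1" (weighted F1),
# so the score below is F1, not accuracy; use setMetricName("accuracy") for accuracy
evaluator = MulticlassClassificationEvaluator(). \
setLabelCol("indexedLabel"). \
setPredictionCol("prediction"). \
setMetricName("f1")
lrF1 = evaluator.evaluate(lrPredictions)
lrF1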
+---------------+---------------+-----------------+--------------------+
| predictedLabel| _c4| features| probability|
+---------------+---------------+-----------------+--------------------+
| Iris-setosa| Iris-setosa|[4.4,3.0,1.3,0.2]|[0.55438709514625...|
| Iris-setosa| Iris-setosa|[4.5,2.3,1.3,0.3]|[0.54484378407858...|
| Iris-setosa| Iris-setosa|[4.6,3.2,1.4,0.2]|[0.54764212881478...|
| Iris-setosa| Iris-setosa|[4.6,3.6,1.0,0.2]|[0.57449259710402...|
| Iris-setosa| Iris-setosa|[4.8,3.0,1.4,0.1]|[0.55711960386173...|
| Iris-setosa| Iris-setosa|[4.8,3.4,1.9,0.2]|[0.51370352584700...|
| Iris-setosa| Iris-setosa|[4.9,3.1,1.5,0.1]|[0.55038245390810...|
| Iris-setosa| Iris-setosa|[4.9,3.1,1.5,0.1]|[0.55038245390810...|
|Iris-versicolor|Iris-versicolor|[5.0,2.0,3.5,1.0]|[0.33266831839072...|
| Iris-setosa| Iris-setosa|[5.0,3.3,1.4,0.2]|[0.54764212881478...|
| Iris-setosa| Iris-setosa|[5.1,3.5,1.4,0.3]|[0.53807472588121...|
| Iris-setosa| Iris-setosa|[5.1,3.8,1.9,0.4]|[0.49437428087769...|
| Iris-setosa| Iris-setosa|[5.2,3.4,1.4,0.2]|[0.54764212881478...|
| Iris-setosa| Iris-setosa|[5.3,3.7,1.5,0.2]|[0.54087961843218...|
| Iris-setosa| Iris-setosa|[5.4,3.4,1.7,0.2]|[0.52731180331887...|
|Iris-versicolor|Iris-versicolor|[5.5,2.4,3.8,1.1]|[0.30612463894390...|
|Iris-versicolor|Iris-versicolor|[5.6,2.5,3.9,1.1]|[0.30036420997267...|
|Iris-versicolor|Iris-versicolor|[5.6,2.7,4.2,1.3]|[0.26721917673024...|
|Iris-versicolor|Iris-versicolor|[5.6,3.0,4.1,1.3]|[0.27259161091433...|
| Iris-virginica|Iris-versicolor|[5.6,3.0,4.5,1.5]|[0.23632948715024...|
+---------------+---------------+-----------------+--------------------+
only showing top 20 rows
0.8417582417582418
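The number above comes from the evaluator's default `metricName="f1"`, that is, weighted F1, which can differ from plain accuracy. The plain-Python sketch below computes both on a small made-up example. Weighted F1 averages each class's F1 score, weighted by that class's share of the true labels. The helper names are hypothetical.

```python
from collections import Counter

def accuracy(y_true, y_pred):
    """Fraction of predictions equal to the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def weighted_f1(y_true, y_pred):
    """Per-class F1, averaged with weights equal to each class's share of y_true."""
    n = len(y_true)
    total = 0.0
    for label, count in Counter(y_true).items():
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = count - tp
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        total += count / n * f1
    return total

# Made-up labels: 3 classes, 2 samples each, 2 mistakes
y_true = [0.0, 0.0, 1.0, 1.0, 2.0, 2.0]
y_pred = [0.0, 0.0, 0.0, 1.0, 2.0, 1.0]
print(accuracy(y_true, y_pred))     # 0.6666666666666666
print(weighted_f1(y_true, y_pred))  # about 0.656 (59/90 in exact arithmetic)
```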
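Next, the data is passed through each fitted stage of the pipeline model one at a time. Finally, the loaded pipeline model produces the same result in a single step. The label column _c4 is dropped first because prediction does not need it.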
schema = StructType([
StructField("_c0", DoubleType(), True),
StructField("_c1", DoubleType(), True),
StructField("_c2", DoubleType(), True),
StructField("_c3", DoubleType(), True),
StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
data = data.drop('_c4')
data.show(5)
data1 = lrPipelineModel.stages[0].transform(data)
data1.show(5)
data2 = lrPipelineModel.stages[1].transform(data1)
data2.show(5)
data3 = lrPipelineModel.stages[2].transform(data2)
data3.show(5)
data4 = lrPipelineModel.stages[3].transform(data3)
data4.show(5)
l_r = PipelineModel.load('./data/lr_model')
lrPredictions = l_r.transform(data)
lrPredictions.show(5)
+---+---+---+---+
|_c0|_c1|_c2|_c3|
+---+---+---+---+
|5.1|3.5|1.4|0.2|
|4.9|3.0|1.4|0.2|
|4.7|3.2|1.3|0.2|
|4.6|3.1|1.5|0.2|
|5.0|3.6|1.4|0.2|
+---+---+---+---+
only showing top 5 rows
+---+---+---+---+-----------------+
|_c0|_c1|_c2|_c3| features|
+---+---+---+---+-----------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|
+---+---+---+---+-----------------+
only showing top 5 rows
+---+---+---+---+-----------------+-----------------+
|_c0|_c1|_c2|_c3| features| indexedFeatures|
+---+---+---+---+-----------------+-----------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|
+---+---+---+---+-----------------+-----------------+
only showing top 5 rows
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
|_c0|_c1|_c2|_c3| features| indexedFeatures| rawPrediction| probability|prediction|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...| 0.0|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...| 0.0|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...| 0.0|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...| 0.0|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...| 0.0|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
only showing top 5 rows
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|_c0|_c1|_c2|_c3| features| indexedFeatures| rawPrediction| probability|prediction|predictedLabel|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...| 0.0| Iris-setosa|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...| 0.0| Iris-setosa|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...| 0.0| Iris-setosa|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...| 0.0| Iris-setosa|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...| 0.0| Iris-setosa|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 5 rows
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|_c0|_c1|_c2|_c3| features| indexedFeatures| rawPrediction| probability|prediction|predictedLabel|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...| 0.0| Iris-setosa|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...| 0.0| Iris-setosa|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...| 0.0| Iris-setosa|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...| 0.0| Iris-setosa|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...| 0.0| Iris-setosa|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 5 rows