Spark机器学习
- 本节用Spark提供的机器学习库(Spark MLib)实现线性回归
创建Spark Session,连接到Spark集群
from pyspark.sql import SparkSession
mySpark = SparkSession.builder.appName('My_LR').master('local').getOrCreate()
/usr/local/lib/python3.6/site-packages/pyspark/context.py:238: FutureWarning: Python 3.6 support is deprecated in Spark 3.2.
FutureWarning
myDF = mySpark.read.format("csv").option("inferSchema","true").option("header","true").load("women.csv")
myDF
DataFrame[_c0: int, height: int, weight: int]
显示数据框myDF的前5行,可以用take()来代替方法head()
myDF.head(5)
[Row(_c0=1, height=58, weight=115),
Row(_c0=2, height=59, weight=117),
Row(_c0=3, height=60, weight=120),
Row(_c0=4, height=61, weight=123),
Row(_c0=5, height=62, weight=126)]
myDF.take(5)
[Row(_c0=1, height=58, weight=115),
Row(_c0=2, height=59, weight=117),
Row(_c0=3, height=60, weight=120),
Row(_c0=4, height=61, weight=123),
Row(_c0=5, height=62, weight=126)]
数据理解
查看Spark数据框对象myDF的模式信息
myDF.printSchema()
root
|-- _c0: integer (nullable = true)
|-- height: integer (nullable = true)
|-- weight: integer (nullable = true)
myDF.describe().toPandas().transpose()
0 | 1 | 2 | 3 | 4 | |
---|---|---|---|---|---|
summary | count | mean | stddev | min | max |
_c0 | 15 | 8.0 | 4.47213595499958 | 1 | 15 |
height | 15 | 65.0 | 4.47213595499958 | 58 | 72 |
weight | 15 | 136.73333333333332 | 15.498694261437752 | 115 | 164 |
数据准备
- 定义特征矩阵
from pyspark.ml.feature import VectorAssembler
VectorAssembler = VectorAssembler(inputCols = ['height'],outputCol = 'features')
v_myDF = VectorAssembler.transform(myDF)
v_myDF.take(3)
[Row(_c0=1, height=58, weight=115, features=DenseVector([58.0])),
Row(_c0=2, height=59, weight=117, features=DenseVector([59.0])),
Row(_c0=3, height=60, weight=120, features=DenseVector([60.0]))]
输出结果中的DenseVector的含义为"密集向量"。在Spark中,向量分为密集(Desense)向量和稀疏(Sparse)向量
提取自变量features 和因变量weight
v_myDF = v_myDF.select(['features','weight'])
v_myDF.take(3)
[Row(features=DenseVector([58.0]), weight=115),
Row(features=DenseVector([59.0]), weight=117),
Row(features=DenseVector([60.0]), weight=120)]
训练集测试集切分
train_df = v_myDF
test_df = v_myDF
模型训练
- 用Spark Mlib的LinearRegression()函数进行简单线性回归
from pyspark.ml.regression import LinearRegression
# 形参featuresCol 和label()
myModel = LinearRegression(featuresCol = 'features',labelCol = 'weight')
myResults = myModel.fit(train_df)
print("Coefficients:" + str(myResults.coefficients))
print("Intercept:" + str(myResults.intercept))
Coefficients:[3.4499999999999913]
Intercept:-87.51666666666614
summary = myResults.summary
模型评价
# 查看“残差”
summary.residuals.take(15)
/usr/local/lib/python3.6/site-packages/pyspark/sql/context.py:127: FutureWarning: Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.
FutureWarning
[Row(residuals=2.416666666666657),
Row(residuals=0.9666666666666401),
Row(residuals=0.5166666666666515),
Row(residuals=0.06666666666666288),
Row(residuals=-0.38333333333332575),
Row(residuals=-0.8333333333333144),
Row(residuals=-1.283333333333303),
Row(residuals=-1.7333333333332916),
Row(residuals=-1.1833333333332803),
Row(residuals=-1.633333333333269),
Row(residuals=-1.0833333333332575),
Row(residuals=-0.5333333333332462),
Row(residuals=0.016666666666736774),
Row(residuals=1.5666666666667481),
Row(residuals=3.1166666666667595)]
# 查看“判断系数R方”
summary.r2
0.9910098326857506
#查看“均方残差”
summary.rootMeanSquaredError
1.419702629269787
# 预测
predictions = myResults.transform(test_df)
predictions.show()
+--------+------+------------------+
|features|weight| prediction|
+--------+------+------------------+
| [58.0]| 115|112.58333333333334|
| [59.0]| 117|116.03333333333336|
| [60.0]| 120|119.48333333333335|
| [61.0]| 123|122.93333333333334|
| [62.0]| 126|126.38333333333333|
| [63.0]| 129|129.83333333333331|
| [64.0]| 132| 133.2833333333333|
| [65.0]| 135| 136.7333333333333|
| [66.0]| 139|140.18333333333328|
| [67.0]| 142|143.63333333333327|
| [68.0]| 146|147.08333333333326|
| [69.0]| 150|150.53333333333325|
| [70.0]| 154|153.98333333333326|
| [71.0]| 159|157.43333333333325|
| [72.0]| 164|160.88333333333324|
+--------+------+------------------+
predictions.select('prediction').show()
+------------------+
| prediction|
+------------------+
|112.58333333333334|
|116.03333333333336|
|119.48333333333335|
|122.93333333333334|
|126.38333333333333|
|129.83333333333331|
| 133.2833333333333|
| 136.7333333333333|
|140.18333333333328|
|143.63333333333327|
|147.08333333333326|
|150.53333333333325|
|153.98333333333326|
|157.43333333333325|
|160.88333333333324|
+------------------+
mySpark.stop()