PySpark Study Notes

from pyspark.sql import SparkSession
## set the Spark master URL: 'local' runs locally on one core, 'local[2]' runs locally with 2 cores
spark = SparkSession.builder.master('local[2]').appName('Basics').getOrCreate()

I. Spark SQL

df = spark.read.csv('appl_stock.csv',inferSchema=True,header=True)
df.show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|               Date|      Open|      High|               Low|             Close|   Volume|         Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996|        214.009998|123432400|         27.727039|
|2010-01-05 00:00:00|214.599998|215.589994|        213.249994|        214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993|    215.23|        210.750004|        210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00|    211.75|212.000006|        209.050005|            210.58|119282800|          27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700|         27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
df.createOrReplaceTempView('stock') # register a temporary view (session-scoped, not a persistent Hive table)
## query the stock view with Spark SQL
result = spark.sql("SELECT * FROM stock LIMIT 5")
result.show()
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|               Date|      Open|      High|               Low|             Close|   Volume|         Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996|        214.009998|123432400|         27.727039|
|2010-01-05 00:00:00|214.599998|215.589994|        213.249994|        214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993|    215.23|        210.750004|        210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00|    211.75|212.000006|        209.050005|            210.58|119282800|          27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700|         27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
## count the rows in stock where Close is greater than 500
spark.sql("SELECT COUNT(Close) FROM stock WHERE Close > 500").show()
+------------+
|count(Close)|
+------------+
|         403|
+------------+
## average of the Open column for rows where Volume > 120000000 or Volume < 110000000
spark.sql("SELECT AVG(Open) as open_avg FROM stock WHERE Volume > 120000000 OR Volume < 110000000").show()
+------------------+
|          open_avg|
+------------------+
|309.12406365290224|
+------------------+
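The same aggregation can also be written with the DataFrame API instead of SQL. A minimal sketch using the df loaded above (the open_avg alias is just illustrative):
from pyspark.sql.functions import avg
df.filter((df['Volume'] > 120000000) | (df['Volume'] < 110000000)) \
    .agg(avg('Open').alias('open_avg')).show()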
## run SQL directly against the .csv file; without header handling the columns come back as _c0, _c1, ... and the header row appears as data
spark.sql("SELECT * FROM csv.`appl_stock.csv`").show(5)
+----------+----------+----------+------------------+----------+---------+------------------+
|       _c0|       _c1|       _c2|               _c3|       _c4|      _c5|               _c6|
+----------+----------+----------+------------------+----------+---------+------------------+
|      Date|      Open|      High|               Low|     Close|   Volume|         Adj Close|
|2010-01-04|213.429998|214.499996|212.38000099999996|214.009998|123432400|         27.727039|
|2010-01-05|214.599998|215.589994|        213.249994|214.379993|150476200|27.774976000000002|
|2010-01-06|214.379993|    215.23|        210.750004|210.969995|138040000|27.333178000000004|
|2010-01-07|    211.75|212.000006|        209.050005|    210.58|119282800|          27.28265|
+----------+----------+----------+------------------+----------+---------+------------------+
only showing top 5 rows

II. DataFrame

1. Read a text file as a DataFrame

textFile = spark.read.text('textstudy.md')
textFile.printSchema()
root
 |-- value: string (nullable = true)

DataFrame to RDD

# convert to an RDD: a DataFrame is a collection of Row objects
textFile.rdd.map(lambda x: x[0]).collect()
['hello china',
 'hello shanghai',
 'hello meituandianping',
 'hello love',
 'hello future']
# word count
textFile_rdd = textFile.rdd.map(list).map(lambda x: x[0])
words = textFile_rdd.flatMap(lambda line: line.split(" "))
not_empty = words.filter(lambda x: x!='') 
key_values= not_empty.map(lambda word: (word, 1)) 
counts= key_values.reduceByKey(lambda a, b: a + b)
counts.collect()
[('hello', 5),
 ('china', 1),
 ('shanghai', 1),
 ('meituandianping', 1),
 ('love', 1),
 ('future', 1)]
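The same word count can be done without dropping to the RDD API, using explode and split from pyspark.sql.functions. A sketch against the same textFile DataFrame (the word_counts name is illustrative):
from pyspark.sql.functions import explode, split, col
word_counts = (textFile
    .select(explode(split(col('value'), ' ')).alias('word'))
    .filter(col('word') != '')
    .groupBy('word').count())
word_counts.show()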

2. Read a JSON file as a DataFrame

df = spark.read.json('people.json')
df.show()
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+
df.printSchema()
root
 |-- age: long (nullable = true)
 |-- name: string (nullable = true)
df.columns
['age', 'name']
df.describe().show()
+-------+------------------+-------+
|summary|               age|   name|
+-------+------------------+-------+
|  count|                 2|      3|
|   mean|              24.5|   null|
| stddev|7.7781745930520225|   null|
|    min|                19|   Andy|
|    max|                30|Michael|
+-------+------------------+-------+
df.summary().show()
+-------+------------------+-------+
|summary|               age|   name|
+-------+------------------+-------+
|  count|                 2|      3|
|   mean|              24.5|   null|
| stddev|7.7781745930520225|   null|
|    min|                19|   Andy|
|    25%|                19|   null|
|    50%|                19|   null|
|    75%|                30|   null|
|    max|                30|Michael|
+-------+------------------+-------+

Handling missing values

# drop rows that contain any null value
df.na.drop().show()
+---+------+
|age|  name|
+---+------+
| 30|  Andy|
| 19|Justin|
+---+------+
# keep only rows with at least 1 non-null value (thresh=1)
df.na.drop(thresh=1).show()
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+
# drop rows where the age column is null
df.na.drop(subset=["age"]).show()
+---+------+
|age|  name|
+---+------+
| 30|  Andy|
| 19|Justin|
+---+------+
# how='any' drops a row if any column is null; how='all' drops it only if every column is null
df.na.drop(how='any').show()
+---+------+
|age|  name|
+---+------+
| 30|  Andy|
| 19|Justin|
+---+------+
# fill nulls in the name column with 0 (no effect here: a numeric fill value is ignored for string columns)
df.na.fill(0, subset=['name']).show()
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+
from pyspark.sql.functions import mean
df = df.na.fill(df.select(mean(df['age'])).collect()[0][0], subset=['age'])
df.show()
df.printSchema()
+---+-------+
|age|   name|
+---+-------+
| 24|Michael|
| 30|   Andy|
| 19| Justin|
+---+-------+

root
 |-- age: long (nullable = true)
 |-- name: string (nullable = true)
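As an alternative to na.fill, pyspark.ml.feature offers an Imputer transformer (Spark 2.2+). A sketch on the same people.json data; Imputer expects float/double input columns, so age is cast first, and the people / age_imputed names are just illustrative:
from pyspark.sql.functions import col
from pyspark.ml.feature import Imputer
people = spark.read.json('people.json').withColumn('age', col('age').cast('double'))
imputer = Imputer(strategy='mean', inputCols=['age'], outputCols=['age_imputed'])
imputer.fit(people).transform(people).show()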

3. Read a CSV file as a DataFrame

df = spark.read.csv('appl_stock.csv',inferSchema=True,header=True)
df.show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|               Date|      Open|      High|               Low|             Close|   Volume|         Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996|        214.009998|123432400|         27.727039|
|2010-01-05 00:00:00|214.599998|215.589994|        213.249994|        214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993|    215.23|        210.750004|        210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00|    211.75|212.000006|        209.050005|            210.58|119282800|          27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700|         27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
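Instead of inferSchema (which needs an extra pass over the file), the schema can be declared explicitly. A sketch; the types below simply mirror what inferSchema produced above, and df_typed is an illustrative name:
from pyspark.sql.types import StructType, StructField, TimestampType, DoubleType, IntegerType
schema = StructType([
    StructField('Date', TimestampType(), True),
    StructField('Open', DoubleType(), True),
    StructField('High', DoubleType(), True),
    StructField('Low', DoubleType(), True),
    StructField('Close', DoubleType(), True),
    StructField('Volume', IntegerType(), True),
    StructField('Adj Close', DoubleType(), True),
])
df_typed = spark.read.csv('appl_stock.csv', header=True, schema=schema)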

4. Functions

(1)filter function

df.filter("Close < 500").show(5)
df.filter(df['Close']<500).show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|               Date|      Open|      High|               Low|             Close|   Volume|         Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996|        214.009998|123432400|         27.727039|
|2010-01-05 00:00:00|214.599998|215.589994|        213.249994|        214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993|    215.23|        210.750004|        210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00|    211.75|212.000006|        209.050005|            210.58|119282800|          27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700|         27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows

+-------------------+----------+----------+------------------+------------------+---------+------------------+
|               Date|      Open|      High|               Low|             Close|   Volume|         Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996|        214.009998|123432400|         27.727039|
|2010-01-05 00:00:00|214.599998|215.589994|        213.249994|        214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993|    215.23|        210.750004|        210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00|    211.75|212.000006|        209.050005|            210.58|119282800|          27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700|         27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
df.filter("Close < 500 AND Open > 500").show(5)
df.filter((df['Close']<500) & (df['Open']>500)).show(5)
+-------------------+----------+------------------+------------------+------------------+---------+---------+
|               Date|      Open|              High|               Low|             Close|   Volume|Adj Close|
+-------------------+----------+------------------+------------------+------------------+---------+---------+
|2012-02-15 00:00:00|514.259995|        526.290016|496.88998399999997|        497.669975|376530000|64.477899|
|2013-09-05 00:00:00|500.250008|500.67997699999995|493.63997699999993|495.26997400000005| 59091900|65.977837|
|2013-09-10 00:00:00|506.199997|        507.450012|        489.500015|494.63999900000005|185798900|65.893915|
|2014-01-30 00:00:00|502.539993|506.49997699999994|         496.70002|        499.779984|169625400|66.967353|
+-------------------+----------+------------------+------------------+------------------+---------+---------+

+-------------------+----------+------------------+------------------+------------------+---------+---------+
|               Date|      Open|              High|               Low|             Close|   Volume|Adj Close|
+-------------------+----------+------------------+------------------+------------------+---------+---------+
|2012-02-15 00:00:00|514.259995|        526.290016|496.88998399999997|        497.669975|376530000|64.477899|
|2013-09-05 00:00:00|500.250008|500.67997699999995|493.63997699999993|495.26997400000005| 59091900|65.977837|
|2013-09-10 00:00:00|506.199997|        507.450012|        489.500015|494.63999900000005|185798900|65.893915|
|2014-01-30 00:00:00|502.539993|506.49997699999994|         496.70002|        499.779984|169625400|66.967353|
+-------------------+----------+------------------+------------------+------------------+---------+---------+
df.filter("Close < 500").select(['Date','Open','Close']).show(7)
df.filter(df['Close']<500).select(['Date','Open','Close']).show(7)
+-------------------+------------------+------------------+
|               Date|              Open|             Close|
+-------------------+------------------+------------------+
|2010-01-04 00:00:00|        213.429998|        214.009998|
|2010-01-05 00:00:00|        214.599998|        214.379993|
|2010-01-06 00:00:00|        214.379993|        210.969995|
|2010-01-07 00:00:00|            211.75|            210.58|
|2010-01-08 00:00:00|        210.299994|211.98000499999998|
|2010-01-11 00:00:00|212.79999700000002|210.11000299999998|
|2010-01-12 00:00:00|209.18999499999998|        207.720001|
+-------------------+------------------+------------------+
only showing top 7 rows

+-------------------+------------------+------------------+
|               Date|              Open|             Close|
+-------------------+------------------+------------------+
|2010-01-04 00:00:00|        213.429998|        214.009998|
|2010-01-05 00:00:00|        214.599998|        214.379993|
|2010-01-06 00:00:00|        214.379993|        210.969995|
|2010-01-07 00:00:00|            211.75|            210.58|
|2010-01-08 00:00:00|        210.299994|211.98000499999998|
|2010-01-11 00:00:00|212.79999700000002|210.11000299999998|
|2010-01-12 00:00:00|209.18999499999998|        207.720001|
+-------------------+------------------+------------------+
only showing top 7 rows
df.filter("Low == 197.16").show()
df.filter(df['Low']==197.16).show()
+-------------------+------------------+----------+------+------+---------+---------+
|               Date|              Open|      High|   Low| Close|   Volume|Adj Close|
+-------------------+------------------+----------+------+------+---------+---------+
|2010-01-22 00:00:00|206.78000600000001|207.499996|197.16|197.75|220441900|25.620401|
+-------------------+------------------+----------+------+------+---------+---------+

+-------------------+------------------+----------+------+------+---------+---------+
|               Date|              Open|      High|   Low| Close|   Volume|Adj Close|
+-------------------+------------------+----------+------+------+---------+---------+
|2010-01-22 00:00:00|206.78000600000001|207.499996|197.16|197.75|220441900|25.620401|
+-------------------+------------------+----------+------+------+---------+---------+
select rows by index
## add index col as first column
header = ['index'] + df.columns
new_df = df.rdd.zipWithIndex().map(lambda x: [x[1]] + list(x[0])).toDF(header)
new_df.filter(new_df.index.isin([1,2,4,6,9])).show(2)
+-----+-------------------+----------+----------+----------+----------+---------+------------------+
|index|               Date|      Open|      High|       Low|     Close|   Volume|         Adj Close|
+-----+-------------------+----------+----------+----------+----------+---------+------------------+
|    1|2010-01-05 00:00:00|214.599998|215.589994|213.249994|214.379993|150476200|27.774976000000002|
|    2|2010-01-06 00:00:00|214.379993|    215.23|210.750004|210.969995|138040000|27.333178000000004|
+-----+-------------------+----------+----------+----------+----------+---------+------------------+
only showing top 2 rows
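A lighter-weight alternative is monotonically_increasing_id, which avoids the RDD round trip; note it only guarantees unique, increasing ids, not consecutive row numbers. A sketch:
from pyspark.sql.functions import monotonically_increasing_id
df.withColumn('index', monotonically_increasing_id()).show(2)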

(2)select function

df.select('Low').show(5)
+------------------+
|               Low|
+------------------+
|212.38000099999996|
|        213.249994|
|        210.750004|
|        209.050005|
|209.06000500000002|
+------------------+
only showing top 5 rows

(3)drop function

df.drop('Low').show(5)
+-------------------+----------+----------+------------------+---------+------------------+
|               Date|      Open|      High|             Close|   Volume|         Adj Close|
+-------------------+----------+----------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|        214.009998|123432400|         27.727039|
|2010-01-05 00:00:00|214.599998|215.589994|        214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993|    215.23|        210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00|    211.75|212.000006|            210.58|119282800|          27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|211.98000499999998|111902700|         27.464034|
+-------------------+----------+----------+------------------+---------+------------------+
only showing top 5 rows

(4)withColumn function

df_new = df.withColumn('Low_plus',df['Low']+1)
df_new.select("Low_plus", "Low").show(5)
+------------------+------------------+
|          Low_plus|               Low|
+------------------+------------------+
|213.38000099999996|212.38000099999996|
|        214.249994|        213.249994|
|        211.750004|        210.750004|
|        210.050005|        209.050005|
|210.06000500000002|209.06000500000002|
+------------------+------------------+
only showing top 5 rows
df.withColumnRenamed('Low','Low_new').show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|               Date|      Open|      High|           Low_new|             Close|   Volume|         Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996|        214.009998|123432400|         27.727039|
|2010-01-05 00:00:00|214.599998|215.589994|        213.249994|        214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993|    215.23|        210.750004|        210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00|    211.75|212.000006|        209.050005|            210.58|119282800|          27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700|         27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows

(5)groupBy function

df.groupBy('Date').mean().show(5) ## mean could be replaced by min, max, sum, count
+-------------------+------------------+------------------+----------+----------+-----------+------------------+
|               Date|         avg(Open)|         avg(High)|  avg(Low)|avg(Close)|avg(Volume)|    avg(Adj Close)|
+-------------------+------------------+------------------+----------+----------+-----------+------------------+
|2012-03-12 00:00:00| 548.9799879999999|        551.999977|547.000023|551.999977| 1.018206E8|         71.516869|
|2012-11-23 00:00:00|        567.169991|        572.000008|562.600006|571.500023|  6.82066E7|         74.700825|
|2013-02-19 00:00:00|461.10000599999995|        462.730003|453.850014|459.990021| 1.089459E8|60.475753000000005|
|2013-10-08 00:00:00|        489.940025|490.64001500000006|480.540024| 480.93998|  7.27293E7|         64.068854|
|2015-05-18 00:00:00|        128.380005|        130.720001|128.360001|130.190002|  5.08829E7|        125.697198|
+-------------------+------------------+------------------+----------+----------+-----------+------------------+
only showing top 5 rows

(6)orderBy function

df.orderBy('Date').show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|               Date|      Open|      High|               Low|             Close|   Volume|         Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996|        214.009998|123432400|         27.727039|
|2010-01-05 00:00:00|214.599998|215.589994|        213.249994|        214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993|    215.23|        210.750004|        210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00|    211.75|212.000006|        209.050005|            210.58|119282800|          27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700|         27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
df.orderBy(df['Date'].desc()).show(5)
df.orderBy('Date',ascending=False).show(5)
+-------------------+----------+----------+----------+----------+--------+------------------+
|               Date|      Open|      High|       Low|     Close|  Volume|         Adj Close|
+-------------------+----------+----------+----------+----------+--------+------------------+
|2016-12-30 00:00:00|116.650002|117.199997|    115.43|    115.82|30586300|         115.32002|
|2016-12-29 00:00:00|116.449997|117.110001|116.400002|116.730003|15039500|        116.226096|
|2016-12-28 00:00:00|117.519997|118.019997|116.199997|116.760002|20905900|116.25596499999999|
|2016-12-27 00:00:00|116.519997|117.800003|116.489998|117.260002|18296900|116.75380600000001|
|2016-12-23 00:00:00|115.589996|116.519997|115.589996|116.519997|14249500|        116.016995|
+-------------------+----------+----------+----------+----------+--------+------------------+
only showing top 5 rows

+-------------------+----------+----------+----------+----------+--------+------------------+
|               Date|      Open|      High|       Low|     Close|  Volume|         Adj Close|
+-------------------+----------+----------+----------+----------+--------+------------------+
|2016-12-30 00:00:00|116.650002|117.199997|    115.43|    115.82|30586300|         115.32002|
|2016-12-29 00:00:00|116.449997|117.110001|116.400002|116.730003|15039500|        116.226096|
|2016-12-28 00:00:00|117.519997|118.019997|116.199997|116.760002|20905900|116.25596499999999|
|2016-12-27 00:00:00|116.519997|117.800003|116.489998|117.260002|18296900|116.75380600000001|
|2016-12-23 00:00:00|115.589996|116.519997|115.589996|116.519997|14249500|        116.016995|
+-------------------+----------+----------+----------+----------+--------+------------------+
only showing top 5 rows

(7)agg function

df.agg({'Volume':'sum'}).show()
+------------+
| sum(Volume)|
+------------+
|166025817100|
+------------+
df.groupBy('Date').agg({'Volume':'mean'}).show(5)
+-------------------+-----------+
|               Date|avg(Volume)|
+-------------------+-----------+
|2012-03-12 00:00:00| 1.018206E8|
|2012-11-23 00:00:00|  6.82066E7|
|2013-02-19 00:00:00| 1.089459E8|
|2013-10-08 00:00:00|  7.27293E7|
|2015-05-18 00:00:00|  5.08829E7|
+-------------------+-----------+
only showing top 5 rows
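agg also accepts Column expressions, which allows several named aggregations in one pass. A sketch on the same df (aliases are illustrative):
from pyspark.sql.functions import avg, max as max_, min as min_
df.groupBy('Date').agg(avg('Volume').alias('vol_avg'),
                       max_('High').alias('high_max'),
                       min_('Low').alias('low_min')).show(5)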

III. Spark MLlib

spark = SparkSession.builder.appName('test').getOrCreate()

1. Regression

df = spark.read.csv('cruise_ship_info.csv',inferSchema=True,header=True)
df.show(5)
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
only showing top 5 rows

(1)Encode the categorical Cruise_line column as integer indices

from pyspark.ml.feature import StringIndexer
indexer = StringIndexer(inputCol="Cruise_line", outputCol="cruise_cat")
indexed = indexer.fit(df).transform(df)
indexed.show(5)
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|cruise_cat|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|      16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|      16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|       1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|       1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|       1.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
only showing top 5 rows

(2)Assemble the feature columns into a single vector

from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
## VectorAssembler is a transformer that combines the given columns into a single vector column
assembler = VectorAssembler(
  inputCols=['Age',
             'Tonnage',
             'passengers',
             'length',
             'cabins',
             'passenger_density',
             'cruise_cat'],
    outputCol="features")
output = assembler.transform(indexed)
output.select("features", "crew").show(5)
+--------------------+----+
|            features|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
+--------------------+----+
only showing top 5 rows

(3)Split the data into training and test sets

full_data = output.select("features", "crew")
train_data,test_data = full_data.randomSplit([0.8,0.2])

(4)Fit a linear regression model

from pyspark.ml.regression import LinearRegression
lr = LinearRegression(featuresCol = 'features',labelCol='crew',predictionCol='prediction')
lrModel = lr.fit(train_data)
print(lrModel.coefficients)
print(lrModel.intercept)
[-0.017085691500866265,0.0064925570120491225,-0.14616134750393708,0.4009769028859461,0.8720907710851697,0.00012638567124781204,0.04043474402085859]
-0.8703567887087273
trainingSummary = lrModel.summary
print(trainingSummary.rootMeanSquaredError)
print(trainingSummary.r2)
0.9796405605574622
0.9151724396508625
trainingSummary.residuals.show(5)
+--------------------+
|           residuals|
+--------------------+
| -1.3197210112896958|
|  0.2957452235313216|
|   0.648959073658145|
|0.059448228597265285|
| -0.7894144891131782|
+--------------------+
only showing top 5 rows
trainingSummary.predictions.show(5)
+--------------------+-----+------------------+
|            features| crew|        prediction|
+--------------------+-----+------------------+
|[5.0,86.0,21.04,9...|  8.0| 9.319721011289696|
|[5.0,115.0,35.74,...| 12.2|11.904254776468678|
|[5.0,122.0,28.5,1...|  6.7| 6.051040926341855|
|[5.0,133.5,39.59,...|13.13|13.070551771402735|
|[6.0,30.276999999...| 3.55| 4.339414489113178|
+--------------------+-----+------------------+
only showing top 5 rows

(5)Evaluate the model

test_results = lrModel.evaluate(test_data)
print(test_results.rootMeanSquaredError)
print(test_results.meanSquaredError)
print(test_results.r2)
0.7479941744294354
0.5594952849803726
0.9628948277710834
test_results.predictions.show(5)
+--------------------+-----+------------------+
|            features| crew|        prediction|
+--------------------+-----+------------------+
|[4.0,220.0,54.0,1...| 21.0|20.888096986384113|
|[5.0,160.0,36.34,...| 13.6|15.081837739236024|
|[7.0,116.0,31.0,9...| 12.0| 12.70952070366211|
|[8.0,77.499,19.5,...|  9.0| 8.667117440235574|
|[9.0,113.0,26.74,...|12.38|11.360531000252465|
+--------------------+-----+------------------+
only showing top 5 rows
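The same metrics can also be computed with RegressionEvaluator, which works for any model that produces a prediction column. A sketch against the test predictions above (reg_eval is an illustrative name):
from pyspark.ml.evaluation import RegressionEvaluator
reg_eval = RegressionEvaluator(labelCol='crew', predictionCol='prediction', metricName='rmse')
print(reg_eval.evaluate(test_results.predictions))
print(reg_eval.evaluate(test_results.predictions, {reg_eval.metricName: 'r2'}))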

(6)Make predictions

predictions = lrModel.transform(test_data.select('features'))
predictions.show(5)
+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[4.0,220.0,54.0,1...|20.888096986384113|
|[5.0,160.0,36.34,...|15.081837739236024|
|[7.0,116.0,31.0,9...| 12.70952070366211|
|[8.0,77.499,19.5,...| 8.667117440235574|
|[9.0,113.0,26.74,...|11.360531000252465|
+--------------------+------------------+
only showing top 5 rows

(7)Correlation between features and the label

from pyspark.sql.functions import corr
df.select(corr('crew','passengers')).show()
+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+
df.select(corr('crew','cabins')).show()
+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+
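The same check can be run over all numeric feature columns in a short loop; a sketch:
for c in ['Age', 'Tonnage', 'passengers', 'length', 'cabins', 'passenger_density']:
    df.select(corr('crew', c).alias('corr_crew_' + c)).show()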

2. Classification

data = spark.read.csv('customer_churn.csv',inferSchema=True,header=True)
data.printSchema()
root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: timestamp (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)
 |-- Churn: integer (nullable = true)
data.select('Names','Age','Total_Purchase','Years','Num_Sites','Location','Company','Churn').show(5)
+----------------+----+--------------+-----+---------+--------------------+--------------------+-----+
|           Names| Age|Total_Purchase|Years|Num_Sites|            Location|             Company|Churn|
+----------------+----+--------------+-----+---------+--------------------+--------------------+-----+
|Cameron Williams|42.0|       11066.8| 7.22|      8.0|10265 Elizabeth M...|          Harvey LLC|    1|
|   Kevin Mueller|41.0|      11916.22|  6.5|     11.0|6157 Frank Garden...|          Wilson PLC|    1|
|     Eric Lozano|38.0|      12884.75| 6.67|     12.0|1331 Keith Court ...|Miller, Johnson a...|    1|
|   Phillip White|42.0|       8010.76| 6.71|     10.0|13120 Daniel Moun...|           Smith Inc|    1|
|  Cynthia Norton|37.0|       9191.58| 5.56|      9.0|765 Tricia Row Ka...|          Love-Jones|    1|
+----------------+----+--------------+-----+---------+--------------------+--------------------+-----+
only showing top 5 rows

(1)From continuous to categorical features

from pyspark.ml.feature import Binarizer, Bucketizer
# Binarizer maps a numeric column to 0/1: values above the threshold become 1.0
# here threshold=5, so every Total_Purchase value (all in the thousands) becomes 1.0
binarizer = Binarizer(threshold=5, inputCol='Total_Purchase', outputCol='Total_Purchase_cat')
# Bucketizer discretizes a continuous column into the ranges defined by the split points
# 5 split points generate 4 buckets
bucketizer = Bucketizer(splits=[0, 10, 30, 50, 70], inputCol='Age', outputCol='age_cat')
# pipeline stages
from pyspark.ml import Pipeline
stages = [binarizer, bucketizer]
pipeline = Pipeline(stages=stages)
# fit the pipeline model and transform the data
result = pipeline.fit(data).transform(data)
result.select('Names','Age','Total_Purchase','Years','Num_Sites','Company','Churn','Total_Purchase_cat','age_cat').show(5)
+----------------+----+--------------+-----+---------+--------------------+-----+------------------+-------+
|           Names| Age|Total_Purchase|Years|Num_Sites|             Company|Churn|Total_Purchase_cat|age_cat|
+----------------+----+--------------+-----+---------+--------------------+-----+------------------+-------+
|Cameron Williams|42.0|       11066.8| 7.22|      8.0|          Harvey LLC|    1|               1.0|    2.0|
|   Kevin Mueller|41.0|      11916.22|  6.5|     11.0|          Wilson PLC|    1|               1.0|    2.0|
|     Eric Lozano|38.0|      12884.75| 6.67|     12.0|Miller, Johnson a...|    1|               1.0|    2.0|
|   Phillip White|42.0|       8010.76| 6.71|     10.0|           Smith Inc|    1|               1.0|    2.0|
|  Cynthia Norton|37.0|       9191.58| 5.56|      9.0|          Love-Jones|    1|               1.0|    2.0|
+----------------+----+--------------+-----+---------+--------------------+-----+------------------+-------+
only showing top 5 rows

(2)Assemble columns into the model input features

from pyspark.ml.feature import VectorAssembler
assembler = VectorAssembler(inputCols=['Age',
 'Total_Purchase',
 'Account_Manager',
 'Years',
 'Num_Sites'],outputCol='features')
output = assembler.transform(data)

(3)Split into training and test sets

final_data = output.select('features','churn')
final_data.show(5)
+--------------------+-----+
|            features|churn|
+--------------------+-----+
|[42.0,11066.8,0.0...|    1|
|[41.0,11916.22,0....|    1|
|[38.0,12884.75,0....|    1|
|[42.0,8010.76,0.0...|    1|
|[37.0,9191.58,0.0...|    1|
+--------------------+-----+
only showing top 5 rows
train_churn,test_churn = final_data.randomSplit([0.8,0.2])

(4)Choose and train a model

Method 1: logistic regression

from pyspark.ml.classification import LogisticRegression
lr_churn = LogisticRegression(featuresCol = 'features',labelCol='churn')
model = lr_churn.fit(train_churn)
training_sum = model.summary
training_sum.predictions.show(5)
+--------------------+-----+--------------------+--------------------+----------+
|            features|churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[22.0,11254.38,1....|  0.0|[4.27126932828609...|[0.98622826225292...|       0.0|
|[25.0,9672.03,0.0...|  0.0|[4.32716495449521...|[0.98696716661289...|       0.0|
|[26.0,8787.39,1.0...|  1.0|[0.50691964789309...|[0.62408409087940...|       0.0|
|[26.0,8939.61,0.0...|  0.0|[5.94537216632240...|[0.99738890778859...|       0.0|
|[27.0,8628.8,1.0,...|  0.0|[5.07194344079783...|[0.99376884822224...|       0.0|
+--------------------+-----+--------------------+--------------------+----------+
only showing top 5 rows

(5)Model evaluation

from pyspark.ml.evaluation import BinaryClassificationEvaluator,MulticlassClassificationEvaluator
# evaluate on the test set
pred_and_labels = model.evaluate(test_churn)
pred_and_labels.predictions.show(5)
+--------------------+-----+--------------------+--------------------+----------+
|            features|churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[28.0,8670.98,0.0...|    0|[7.28032028440204...|[0.99931150948212...|       0.0|
|[29.0,5900.78,1.0...|    0|[3.80245014943030...|[0.97817110695536...|       0.0|
|[29.0,9378.24,0.0...|    0|[4.42704540541525...|[0.98819136594128...|       0.0|
|[30.0,8874.83,0.0...|    0|[2.92386878493753...|[0.94901382190383...|       0.0|
|[30.0,10744.14,1....|    1|[1.56910959319232...|[0.82765663668636...|       0.0|
+--------------------+-----+--------------------+--------------------+----------+
only showing top 5 rows
churn_eval = BinaryClassificationEvaluator(rawPredictionCol='prediction',labelCol='churn')
churn_eval_multi = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='churn',metricName='accuracy')
accuracy = churn_eval_multi.evaluate(pred_and_labels.predictions)
accuracy
0.918918918918919
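For an actual area-under-ROC figure, BinaryClassificationEvaluator should be pointed at the rawPrediction column rather than the hard 0/1 prediction. A sketch (auc_eval is an illustrative name):
auc_eval = BinaryClassificationEvaluator(rawPredictionCol='rawPrediction', labelCol='churn',
                                         metricName='areaUnderROC')
print(auc_eval.evaluate(pred_and_labels.predictions))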

(6)Model prediction

churn_test = model.transform(test_churn.select('features'))
churn_test.show(5)
+--------------------+--------------------+--------------------+----------+
|            features|       rawPrediction|         probability|prediction|
+--------------------+--------------------+--------------------+----------+
|[28.0,8670.98,0.0...|[7.28032028440204...|[0.99931150948212...|       0.0|
|[29.0,5900.78,1.0...|[3.80245014943030...|[0.97817110695536...|       0.0|
|[29.0,9378.24,0.0...|[4.42704540541525...|[0.98819136594128...|       0.0|
|[30.0,8874.83,0.0...|[2.92386878493753...|[0.94901382190383...|       0.0|
|[30.0,10744.14,1....|[1.56910959319232...|[0.82765663668636...|       0.0|
+--------------------+--------------------+--------------------+----------+
only showing top 5 rows

Method 2: decision tree

from pyspark.ml.classification import RandomForestClassifier,DecisionTreeClassifier
dtc = DecisionTreeClassifier(labelCol='churn',featuresCol='features')
dtc_model = dtc.fit(train_churn)
print(dtc_model.featureImportances)
(5,[0,1,3,4],[0.09646621280342624,0.09365722250595962,0.14583722780378533,0.6640393368868287])
predictions = dtc_model.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.9081081081081082

Method 3: random forest

rfc = RandomForestClassifier(labelCol="churn", featuresCol="features", numTrees=20)
rfc_model = rfc.fit(train_churn)
print(rfc_model.featureImportances)
(5,[0,1,2,3,4],[0.09112246833941745,0.07289697555412486,0.00807666535024141,0.1859341040975848,0.6419697866586315])
predictions = rfc_model.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.9081081081081082

Method 4: gradient-boosted trees

from pyspark.ml.classification import GBTClassifier
gbt = GBTClassifier(labelCol="churn", featuresCol="features", maxIter=20)
gbt_model = gbt.fit(train_churn)
predictions = gbt_model.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.9027027027027027

Method 5: logistic regression with cross-validation

from pyspark.ml.classification import LogisticRegression
blor = LogisticRegression(featuresCol='features', labelCol='churn', family='binomial')
from pyspark.ml.tuning import ParamGridBuilder
param_grid = ParamGridBuilder().\
    addGrid(blor.regParam, [0, 0.5, 1, 2]).\
    addGrid(blor.elasticNetParam, [0, 0.5, 1]).\
    build()
from pyspark.ml.evaluation import BinaryClassificationEvaluator
evaluator = BinaryClassificationEvaluator()
from pyspark.ml.tuning import CrossValidator
cv = CrossValidator(estimator=blor, estimatorParamMaps=param_grid, evaluator=churn_eval_multi, numFolds=4)
cvModel = cv.fit(train_churn)
cvModel.bestModel.intercept
-18.032119312923868
cvModel.bestModel.coefficients
DenseVector([0.0523, 0.0, 0.4303, 0.5307, 1.1368])
cvModel.bestModel._java_obj.getRegParam()
0.0
cvModel.bestModel._java_obj.getElasticNetParam()
0.0
predictions = cvModel.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.918918918918919
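The per-candidate metrics can also be inspected without reaching into _java_obj: cvModel.avgMetrics lines up with the param grid. A sketch using the param_grid defined above:
for params, metric in zip(param_grid, cvModel.avgMetrics):
    print({p.name: v for p, v in params.items()}, metric)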

Extra: confusion matrix

label_pred_test = predictions.select('churn', 'prediction')
label_pred_test.rdd.zipWithIndex().countByKey()  # count each (churn, prediction) combination
defaultdict(int,
            {Row(churn=0, prediction=0.0): 151,
             Row(churn=1, prediction=0.0): 9,
             Row(churn=1, prediction=1.0): 19,
             Row(churn=0, prediction=1.0): 6})
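DataFrame.crosstab gives the same confusion matrix as a small table; a sketch:
predictions.crosstab('churn', 'prediction').show()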

3. Clustering

data = spark.read.csv("hack_data.csv",header=True,inferSchema=True)
data.printSchema()
root
 |-- Session_Connection_Time: double (nullable = true)
 |-- Bytes Transferred: double (nullable = true)
 |-- Kali_Trace_Used: integer (nullable = true)
 |-- Servers_Corrupted: double (nullable = true)
 |-- Pages_Corrupted: double (nullable = true)
 |-- Location: string (nullable = true)
 |-- WPM_Typing_Speed: double (nullable = true)
data.select('Session_Connection_Time','Bytes Transferred','Kali_Trace_Used','Servers_Corrupted','Pages_Corrupted','WPM_Typing_Speed').show(5)
+-----------------------+-----------------+---------------+-----------------+---------------+----------------+
|Session_Connection_Time|Bytes Transferred|Kali_Trace_Used|Servers_Corrupted|Pages_Corrupted|WPM_Typing_Speed|
+-----------------------+-----------------+---------------+-----------------+---------------+----------------+
|                    8.0|           391.09|              1|             2.96|            7.0|           72.37|
|                   20.0|           720.99|              0|             3.04|            9.0|           69.08|
|                   31.0|           356.32|              1|             3.71|            8.0|           70.58|
|                    2.0|           228.08|              1|             2.48|            8.0|            70.8|
|                   20.0|            408.5|              0|             3.57|            8.0|           71.28|
+-----------------------+-----------------+---------------+-----------------+---------------+----------------+
only showing top 5 rows
data.columns
['Session_Connection_Time',
 'Bytes Transferred',
 'Kali_Trace_Used',
 'Servers_Corrupted',
 'Pages_Corrupted',
 'Location',
 'WPM_Typing_Speed']

(1)Assemble columns into input features

from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
feat_cols = ['Session_Connection_Time', 'Bytes Transferred', 'Kali_Trace_Used',
             'Servers_Corrupted', 'Pages_Corrupted','WPM_Typing_Speed']
vec_assembler = VectorAssembler(inputCols = feat_cols, outputCol='features')
final_data = vec_assembler.transform(data)
final_data.select('features').head(1)[0]
Row(features=DenseVector([8.0, 391.09, 1.0, 2.96, 7.0, 72.37]))

(2)Standardize the features

from pyspark.ml.feature import StandardScaler
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=False)
cluster_final_data = scaler.fit(final_data).transform(final_data)
cluster_final_data.select("scaledFeatures").show(5)
+--------------------+
|      scaledFeatures|
+--------------------+
|[0.56785108466505...|
|[1.41962771166263...|
|[2.20042295307707...|
|[0.14196277116626...|
|[1.41962771166263...|
+--------------------+
only showing top 5 rows
cluster_final_data.select("scaledFeatures").head(1)[0]
Row(scaledFeatures=DenseVector([0.5679, 1.3658, 1.9976, 1.2859, 2.2849, 5.3963]))

(3)K-Means clustering

from pyspark.ml.clustering import KMeans
kmeans = KMeans(featuresCol='scaledFeatures', k=3)
model = kmeans.fit(cluster_final_data)
model.computeCost(cluster_final_data)
434.1492898715845
model.clusterCenters()
[array([1.30217042, 1.25830099, 0.        , 1.35793211, 2.57251009,
        5.24230473]),
 array([2.99991988, 2.92319035, 1.05261534, 3.20390443, 4.51321315,
        3.28474   ]),
 array([1.21780112, 1.37901802, 1.99757683, 1.37198977, 2.55237797,
        5.29152222])]
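To help choose k, the cost can be computed for a range of k values (an elbow check). A sketch using computeCost as above; on Spark 3.x, where computeCost was removed, ClusteringEvaluator would be used instead:
for k in range(2, 9):
    km_model = KMeans(featuresCol='scaledFeatures', k=k).fit(cluster_final_data)
    print(k, km_model.computeCost(cluster_final_data))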

(4)Model prediction

model.transform(cluster_final_data).groupBy('prediction').count().show()
+----------+-----+
|prediction|count|
+----------+-----+
|         1|  167|
|         2|   83|
|         0|   84|
+----------+-----+
model.transform(cluster_final_data).show(5)
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+--------------------+--------------------+----------+
|Session_Connection_Time|Bytes Transferred|Kali_Trace_Used|Servers_Corrupted|Pages_Corrupted|            Location|WPM_Typing_Speed|            features|      scaledFeatures|prediction|
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+--------------------+--------------------+----------+
|                    8.0|           391.09|              1|             2.96|            7.0|            Slovenia|           72.37|[8.0,391.09,1.0,2...|[0.56785108466505...|         2|
|                   20.0|           720.99|              0|             3.04|            9.0|British Virgin Is...|           69.08|[20.0,720.99,0.0,...|[1.41962771166263...|         0|
|                   31.0|           356.32|              1|             3.71|            8.0|             Tokelau|           70.58|[31.0,356.32,1.0,...|[2.20042295307707...|         2|
|                    2.0|           228.08|              1|             2.48|            8.0|             Bolivia|            70.8|[2.0,228.08,1.0,2...|[0.14196277116626...|         2|
|                   20.0|            408.5|              0|             3.57|            8.0|                Iraq|           71.28|[20.0,408.5,0.0,3...|[1.41962771166263...|         0|
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+--------------------+--------------------+----------+
only showing top 5 rows

4. Text mining with TF-IDF

data = spark.read.csv("SMSSpamCollection",inferSchema=True,sep='\t')
data = data.withColumnRenamed('_c0','class').withColumnRenamed('_c1','text')
data.show(5)
+-----+--------------------+
|class|                text|
+-----+--------------------+
|  ham|Go until jurong p...|
|  ham|Ok lar... Joking ...|
| spam|Free entry in 2 a...|
|  ham|U dun say so earl...|
|  ham|Nah I don't think...|
+-----+--------------------+
only showing top 5 rows

(1)Data preprocessing

from pyspark.sql.functions import length
# compute length of each text
data = data.withColumn('length',length(data['text']))
(1.1)Tokenization
from pyspark.ml.feature import Tokenizer,StopWordsRemover,CountVectorizer,IDF,StringIndexer
tokenizer = Tokenizer(inputCol="text", outputCol="stop_tokens")
(1.2)Stop-word removal (left commented out below, which is why the tokenizer writes straight to 'stop_tokens')
# stopremove = StopWordsRemover(inputCol='token_text',outputCol='stop_tokens')
(1.3)Term frequency
count_vec = CountVectorizer(inputCol='stop_tokens',outputCol='c_vec')
(1.4)Inverse document frequency
idf = IDF(inputCol="c_vec", outputCol="tf_idf")
(1.5)Map the class label from string to index
ham_spam_to_num = StringIndexer(inputCol='class',outputCol='label')

(2)Assemble columns into input features

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vector
clean_up = VectorAssembler(inputCols=['tf_idf','length'],outputCol='features')

(3)Build the model

from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()

(4)Build the pipeline

from pyspark.ml import Pipeline
data_prep_pipe = Pipeline(stages=[ham_spam_to_num,tokenizer,count_vec,idf,clean_up])
cleaner = data_prep_pipe.fit(data)
clean_data = cleaner.transform(data)

(5)Split into training and test sets

full_data = clean_data.select(['label','features'])
(train_data,test_data) = full_data.randomSplit([0.8,0.2])

(6)Train the model

model = nb.fit(train_data)
test_results = model.transform(test_data)
test_results.show(5)
+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(13588,[0,1,2,3,4...|[-1350.8171609962...|[1.0,1.0733004084...|       0.0|
|  0.0|(13588,[0,1,2,3,4...|[-3071.3460250107...|[1.0,1.4929420982...|       0.0|
|  0.0|(13588,[0,1,2,3,4...|[-1454.8011163433...|[1.0,1.8318571536...|       0.0|
|  0.0|(13588,[0,1,2,3,4...|[-1169.5412775216...|[1.0,5.0678369468...|       0.0|
|  0.0|(13588,[0,1,2,3,5...|[-1769.7764271667...|[1.0,2.5959352248...|       0.0|
+-----+--------------------+--------------------+--------------------+----------+
only showing top 5 rows

(7)Model evaluation

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
acc_eval = MulticlassClassificationEvaluator()  # note: the default metric is f1; pass metricName='accuracy' for plain accuracy
acc = acc_eval.evaluate(test_results)
print("Accuracy of model at predicting spam was: {}".format(acc))
Accuracy of model at predicting spam was: 0.9416633505993651
spark.stop()