from pyspark.sql import SparkSession
spark = SparkSession.builder.master('local[2]').appName('Basics').getOrCreate()
I. Spark SQL
df = spark.read.csv('appl_stock.csv', inferSchema=True, header=True)
df.show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
| Date| Open| High| Low| Close| Volume| Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996| 214.009998|123432400| 27.727039|
|2010-01-05 00:00:00|214.599998|215.589994| 213.249994| 214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993| 215.23| 210.750004| 210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00| 211.75|212.000006| 209.050005| 210.58|119282800| 27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700| 27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
df.createOrReplaceTempView('stock')
result = spark.sql("SELECT * FROM stock LIMIT 5")
result.show()
+-------------------+----------+----------+------------------+------------------+---------+------------------+
| Date| Open| High| Low| Close| Volume| Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996| 214.009998|123432400| 27.727039|
|2010-01-05 00:00:00|214.599998|215.589994| 213.249994| 214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993| 215.23| 210.750004| 210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00| 211.75|212.000006| 209.050005| 210.58|119282800| 27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700| 27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
spark. sql( "SELECT COUNT(Close) FROM stock WHERE Close > 500" ) . show( )
+------------+
|count(Close)|
+------------+
| 403|
+------------+
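The same count can be written against the DataFrame directly instead of the temp view; a minimal equivalent sketch:
df.filter(df['Close'] > 500).count()  # 403 here, matching the SQL result above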
spark. sql( "SELECT AVG(Open) as open_avg FROM stock WHERE Volume > 120000000 OR Volume < 110000000" ) . show( )
+------------------+
| open_avg|
+------------------+
|309.12406365290224|
+------------------+
spark. sql( "SELECT * FROM csv.`appl_stock.csv`" ) . show( 5 )
+----------+----------+----------+------------------+----------+---------+------------------+
| _c0| _c1| _c2| _c3| _c4| _c5| _c6|
+----------+----------+----------+------------------+----------+---------+------------------+
| Date| Open| High| Low| Close| Volume| Adj Close|
|2010-01-04|213.429998|214.499996|212.38000099999996|214.009998|123432400| 27.727039|
|2010-01-05|214.599998|215.589994| 213.249994|214.379993|150476200|27.774976000000002|
|2010-01-06|214.379993| 215.23| 210.750004|210.969995|138040000|27.333178000000004|
|2010-01-07| 211.75|212.000006| 209.050005| 210.58|119282800| 27.28265|
+----------+----------+----------+------------------+----------+---------+------------------+
only showing top 5 rows
Note that querying the file directly this way applies no reader options: there is no header handling or schema inference, so every column comes back as a string named _c0 through _c6, and the header row appears as the first data row.
II. DataFrame
1. Read a text file as a DataFrame
textFile = spark.read.text('textstudy.md')
textFile.printSchema()
root
|-- value: string (nullable = true)
Convert the DataFrame to an RDD:
textFile.rdd.map(lambda x: x[0]).collect()
['hello china',
'hello shanghai',
'hello meituandianping',
'hello love',
'hello future']
textFile_rdd = textFile.rdd.map(list).map(lambda x: x[0])
words = textFile_rdd.flatMap(lambda line: line.split(" "))
not_empty = words.filter(lambda x: x != '')
key_values = not_empty.map(lambda word: (word, 1))
counts = key_values.reduceByKey(lambda a, b: a + b)
counts.collect()
[('hello', 5),
('china', 1),
('shanghai', 1),
('meituandianping', 1),
('love', 1),
('future', 1)]
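For comparison, the same word count can stay in the DataFrame API without dropping to the RDD level; a sketch using split and explode:
from pyspark.sql.functions import explode, split
words_df = textFile.select(explode(split(textFile['value'], ' ')).alias('word'))
words_df.filter("word != ''").groupBy('word').count().show()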
2. Read a JSON file as a DataFrame
df = spark.read.json('people.json')
df.show()
+----+-------+
| age| name|
+----+-------+
|null|Michael|
| 30| Andy|
| 19| Justin|
+----+-------+
df.printSchema()
root
|-- age: long (nullable = true)
|-- name: string (nullable = true)
df.columns
['age', 'name']
df.describe().show()
+-------+------------------+-------+
|summary| age| name|
+-------+------------------+-------+
| count| 2| 3|
| mean| 24.5| null|
| stddev|7.7781745930520225| null|
| min| 19| Andy|
| max| 30|Michael|
+-------+------------------+-------+
df.summary().show()
+-------+------------------+-------+
|summary| age| name|
+-------+------------------+-------+
| count| 2| 3|
| mean| 24.5| null|
| stddev|7.7781745930520225| null|
| min| 19| Andy|
| 25%| 19| null|
| 50%| 19| null|
| 75%| 30| null|
| max| 30|Michael|
+-------+------------------+-------+
Handling missing values
df.na.drop().show()
+---+------+
|age| name|
+---+------+
| 30| Andy|
| 19|Justin|
+---+------+
df.na.drop(thresh=1).show()  # keep rows with at least 1 non-null value
+----+-------+
| age| name|
+----+-------+
|null|Michael|
| 30| Andy|
| 19| Justin|
+----+-------+
df.na.drop(subset=["age"]).show()
+---+------+
|age| name|
+---+------+
| 30| Andy|
| 19|Justin|
+---+------+
df.na.drop(how='any').show()
+---+------+
|age| name|
+---+------+
| 30| Andy|
| 19|Justin|
+---+------+
df.na.fill(0, subset=['name']).show()  # a numeric fill value does not apply to a string column, so nothing changes
+----+-------+
| age| name|
+----+-------+
|null|Michael|
| 30| Andy|
| 19| Justin|
+----+-------+
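na.fill matches the fill value's type against each column, which is why the numeric 0 above leaves the string column name unchanged (and name has no nulls here anyway). A string value targets string columns instead; a minimal sketch:
df.na.fill('unknown', subset=['name']).show()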
from pyspark.sql.functions import mean
mean_age = df.select(mean(df['age'])).collect()[0][0]  # 24.5
df = df.na.fill(mean_age, subset=['age'])  # age is a long column, so 24.5 is stored as 24
df.show()
df.printSchema()
+---+-------+
|age| name|
+---+-------+
| 24|Michael|
| 30| Andy|
| 19| Justin|
+---+-------+
root
|-- age: long (nullable = true)
|-- name: string (nullable = true)
3. Read a CSV file as a DataFrame
df = spark.read.csv('appl_stock.csv', inferSchema=True, header=True)
df.show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
| Date| Open| High| Low| Close| Volume| Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996| 214.009998|123432400| 27.727039|
|2010-01-05 00:00:00|214.599998|215.589994| 213.249994| 214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993| 215.23| 210.750004| 210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00| 211.75|212.000006| 209.050005| 210.58|119282800| 27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700| 27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
4. Functions
(1) The filter function
df.filter("Close < 500").show(5)
df.filter(df['Close'] < 500).show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
| Date| Open| High| Low| Close| Volume| Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996| 214.009998|123432400| 27.727039|
|2010-01-05 00:00:00|214.599998|215.589994| 213.249994| 214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993| 215.23| 210.750004| 210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00| 211.75|212.000006| 209.050005| 210.58|119282800| 27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700| 27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
+-------------------+----------+----------+------------------+------------------+---------+------------------+
| Date| Open| High| Low| Close| Volume| Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996| 214.009998|123432400| 27.727039|
|2010-01-05 00:00:00|214.599998|215.589994| 213.249994| 214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993| 215.23| 210.750004| 210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00| 211.75|212.000006| 209.050005| 210.58|119282800| 27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700| 27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
df. filter ( "Close < 500 AND Open > 500" ) . show( 5 )
df. filter ( ( df[ 'Close' ] < 500 ) & ( df[ 'Open' ] > 500 ) ) . show( 5 )
+-------------------+----------+------------------+------------------+------------------+---------+---------+
| Date| Open| High| Low| Close| Volume|Adj Close|
+-------------------+----------+------------------+------------------+------------------+---------+---------+
|2012-02-15 00:00:00|514.259995| 526.290016|496.88998399999997| 497.669975|376530000|64.477899|
|2013-09-05 00:00:00|500.250008|500.67997699999995|493.63997699999993|495.26997400000005| 59091900|65.977837|
|2013-09-10 00:00:00|506.199997| 507.450012| 489.500015|494.63999900000005|185798900|65.893915|
|2014-01-30 00:00:00|502.539993|506.49997699999994| 496.70002| 499.779984|169625400|66.967353|
+-------------------+----------+------------------+------------------+------------------+---------+---------+
+-------------------+----------+------------------+------------------+------------------+---------+---------+
| Date| Open| High| Low| Close| Volume|Adj Close|
+-------------------+----------+------------------+------------------+------------------+---------+---------+
|2012-02-15 00:00:00|514.259995| 526.290016|496.88998399999997| 497.669975|376530000|64.477899|
|2013-09-05 00:00:00|500.250008|500.67997699999995|493.63997699999993|495.26997400000005| 59091900|65.977837|
|2013-09-10 00:00:00|506.199997| 507.450012| 489.500015|494.63999900000005|185798900|65.893915|
|2014-01-30 00:00:00|502.539993|506.49997699999994| 496.70002| 499.779984|169625400|66.967353|
+-------------------+----------+------------------+------------------+------------------+---------+---------+
df. filter ( "Close < 500" ) . select( [ 'Date' , 'Open' , 'Close' ] ) . show( 7 )
df. filter ( df[ 'Close' ] < 500 ) . select( [ 'Date' , 'Open' , 'Close' ] ) . show( 7 )
+-------------------+------------------+------------------+
| Date| Open| Close|
+-------------------+------------------+------------------+
|2010-01-04 00:00:00| 213.429998| 214.009998|
|2010-01-05 00:00:00| 214.599998| 214.379993|
|2010-01-06 00:00:00| 214.379993| 210.969995|
|2010-01-07 00:00:00| 211.75| 210.58|
|2010-01-08 00:00:00| 210.299994|211.98000499999998|
|2010-01-11 00:00:00|212.79999700000002|210.11000299999998|
|2010-01-12 00:00:00|209.18999499999998| 207.720001|
+-------------------+------------------+------------------+
only showing top 7 rows
+-------------------+------------------+------------------+
| Date| Open| Close|
+-------------------+------------------+------------------+
|2010-01-04 00:00:00| 213.429998| 214.009998|
|2010-01-05 00:00:00| 214.599998| 214.379993|
|2010-01-06 00:00:00| 214.379993| 210.969995|
|2010-01-07 00:00:00| 211.75| 210.58|
|2010-01-08 00:00:00| 210.299994|211.98000499999998|
|2010-01-11 00:00:00|212.79999700000002|210.11000299999998|
|2010-01-12 00:00:00|209.18999499999998| 207.720001|
+-------------------+------------------+------------------+
only showing top 7 rows
df. filter ( "Low == 197.16" ) . show( )
df. filter ( df[ 'Low' ] == 197.16 ) . show( )
+-------------------+------------------+----------+------+------+---------+---------+
| Date| Open| High| Low| Close| Volume|Adj Close|
+-------------------+------------------+----------+------+------+---------+---------+
|2010-01-22 00:00:00|206.78000600000001|207.499996|197.16|197.75|220441900|25.620401|
+-------------------+------------------+----------+------+------+---------+---------+
+-------------------+------------------+----------+------+------+---------+---------+
| Date| Open| High| Low| Close| Volume|Adj Close|
+-------------------+------------------+----------+------+------+---------+---------+
|2010-01-22 00:00:00|206.78000600000001|207.499996|197.16|197.75|220441900|25.620401|
+-------------------+------------------+----------+------+------+---------+---------+
Select rows by index
header = ['index'] + df.columns
new_df = df.rdd.zipWithIndex().map(lambda x: [x[1]] + list(x[0])).toDF(header)
new_df.filter(new_df.index.isin([1, 2, 4, 6, 9])).show(2)
+-----+-------------------+----------+----------+----------+----------+---------+------------------+
|index| Date| Open| High| Low| Close| Volume| Adj Close|
+-----+-------------------+----------+----------+----------+----------+---------+------------------+
| 1|2010-01-05 00:00:00|214.599998|215.589994|213.249994|214.379993|150476200|27.774976000000002|
| 2|2010-01-06 00:00:00|214.379993| 215.23|210.750004|210.969995|138040000|27.333178000000004|
+-----+-------------------+----------+----------+----------+----------+---------+------------------+
only showing top 2 rows
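An alternative that avoids the RDD round-trip is monotonically_increasing_id, which appends a unique 64-bit ID per row (unique, but not guaranteed consecutive across partitions); a sketch:
from pyspark.sql.functions import monotonically_increasing_id
df.withColumn('index', monotonically_increasing_id()).show(2)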
(2) The select function
df.select('Low').show(5)
+------------------+
| Low|
+------------------+
|212.38000099999996|
| 213.249994|
| 210.750004|
| 209.050005|
|209.06000500000002|
+------------------+
only showing top 5 rows
(3) The drop function
df.drop('Low').show(5)
+-------------------+----------+----------+------------------+---------+------------------+
| Date| Open| High| Close| Volume| Adj Close|
+-------------------+----------+----------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996| 214.009998|123432400| 27.727039|
|2010-01-05 00:00:00|214.599998|215.589994| 214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993| 215.23| 210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00| 211.75|212.000006| 210.58|119282800| 27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|211.98000499999998|111902700| 27.464034|
+-------------------+----------+----------+------------------+---------+------------------+
only showing top 5 rows
(4) The withColumn function
df_new = df.withColumn('Low_plus', df['Low'] + 1)
df_new.select("Low_plus", "Low").show(5)
+------------------+------------------+
| Low_plus| Low|
+------------------+------------------+
|213.38000099999996|212.38000099999996|
| 214.249994| 213.249994|
| 211.750004| 210.750004|
| 210.050005| 209.050005|
|210.06000500000002|209.06000500000002|
+------------------+------------------+
only showing top 5 rows
df.withColumnRenamed('Low', 'Low_new').show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
| Date| Open| High| Low_new| Close| Volume| Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996| 214.009998|123432400| 27.727039|
|2010-01-05 00:00:00|214.599998|215.589994| 213.249994| 214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993| 215.23| 210.750004| 210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00| 211.75|212.000006| 209.050005| 210.58|119282800| 27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700| 27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
(5) The groupBy function
df.groupBy('Date').mean().show(5)
+-------------------+------------------+------------------+----------+----------+-----------+------------------+
| Date| avg(Open)| avg(High)| avg(Low)|avg(Close)|avg(Volume)| avg(Adj Close)|
+-------------------+------------------+------------------+----------+----------+-----------+------------------+
|2012-03-12 00:00:00| 548.9799879999999| 551.999977|547.000023|551.999977| 1.018206E8| 71.516869|
|2012-11-23 00:00:00| 567.169991| 572.000008|562.600006|571.500023| 6.82066E7| 74.700825|
|2013-02-19 00:00:00|461.10000599999995| 462.730003|453.850014|459.990021| 1.089459E8|60.475753000000005|
|2013-10-08 00:00:00| 489.940025|490.64001500000006|480.540024| 480.93998| 7.27293E7| 64.068854|
|2015-05-18 00:00:00| 128.380005| 130.720001|128.360001|130.190002| 5.08829E7| 125.697198|
+-------------------+------------------+------------------+----------+----------+-----------+------------------+
only showing top 5 rows
(6) The orderBy function
df.orderBy('Date').show(5)
+-------------------+----------+----------+------------------+------------------+---------+------------------+
| Date| Open| High| Low| Close| Volume| Adj Close|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
|2010-01-04 00:00:00|213.429998|214.499996|212.38000099999996| 214.009998|123432400| 27.727039|
|2010-01-05 00:00:00|214.599998|215.589994| 213.249994| 214.379993|150476200|27.774976000000002|
|2010-01-06 00:00:00|214.379993| 215.23| 210.750004| 210.969995|138040000|27.333178000000004|
|2010-01-07 00:00:00| 211.75|212.000006| 209.050005| 210.58|119282800| 27.28265|
|2010-01-08 00:00:00|210.299994|212.000006|209.06000500000002|211.98000499999998|111902700| 27.464034|
+-------------------+----------+----------+------------------+------------------+---------+------------------+
only showing top 5 rows
df.orderBy(df['Date'].desc()).show(5)
df.orderBy('Date', ascending=False).show(5)
+-------------------+----------+----------+----------+----------+--------+------------------+
| Date| Open| High| Low| Close| Volume| Adj Close|
+-------------------+----------+----------+----------+----------+--------+------------------+
|2016-12-30 00:00:00|116.650002|117.199997| 115.43| 115.82|30586300| 115.32002|
|2016-12-29 00:00:00|116.449997|117.110001|116.400002|116.730003|15039500| 116.226096|
|2016-12-28 00:00:00|117.519997|118.019997|116.199997|116.760002|20905900|116.25596499999999|
|2016-12-27 00:00:00|116.519997|117.800003|116.489998|117.260002|18296900|116.75380600000001|
|2016-12-23 00:00:00|115.589996|116.519997|115.589996|116.519997|14249500| 116.016995|
+-------------------+----------+----------+----------+----------+--------+------------------+
only showing top 5 rows
+-------------------+----------+----------+----------+----------+--------+------------------+
| Date| Open| High| Low| Close| Volume| Adj Close|
+-------------------+----------+----------+----------+----------+--------+------------------+
|2016-12-30 00:00:00|116.650002|117.199997| 115.43| 115.82|30586300| 115.32002|
|2016-12-29 00:00:00|116.449997|117.110001|116.400002|116.730003|15039500| 116.226096|
|2016-12-28 00:00:00|117.519997|118.019997|116.199997|116.760002|20905900|116.25596499999999|
|2016-12-27 00:00:00|116.519997|117.800003|116.489998|117.260002|18296900|116.75380600000001|
|2016-12-23 00:00:00|115.589996|116.519997|115.589996|116.519997|14249500| 116.016995|
+-------------------+----------+----------+----------+----------+--------+------------------+
only showing top 5 rows
(7) The agg function
df.agg({'Volume': 'sum'}).show()
+------------+
| sum(Volume)|
+------------+
|166025817100|
+------------+
df.groupBy('Date').agg({'Volume': 'mean'}).show(5)
+-------------------+-----------+
| Date|avg(Volume)|
+-------------------+-----------+
|2012-03-12 00:00:00| 1.018206E8|
|2012-11-23 00:00:00| 6.82066E7|
|2013-02-19 00:00:00| 1.089459E8|
|2013-10-08 00:00:00| 7.27293E7|
|2015-05-18 00:00:00| 5.08829E7|
+-------------------+-----------+
only showing top 5 rows
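agg also accepts column expressions, so several aggregates can be computed in one pass; a minimal sketch (note these imports shadow Python's built-in max and min):
from pyspark.sql.functions import avg, max, min
df.agg(max('High'), min('Low'), avg('Volume')).show()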
III. Spark MLlib
spark = SparkSession.builder.appName('test').getOrCreate()  # getOrCreate() reuses the session created above
1. Regression
df = spark.read.csv('cruise_ship_info.csv', inferSchema=True, header=True)
df.show(5)
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
| Ship_name|Cruise_line|Age| Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
| Journey| Azamara| 6|30.276999999999997| 6.94| 5.94| 3.55| 42.64|3.55|
| Quest| Azamara| 6|30.276999999999997| 6.94| 5.94| 3.55| 42.64|3.55|
|Celebration| Carnival| 26| 47.262| 14.86| 7.22| 7.43| 31.8| 6.7|
| Conquest| Carnival| 11| 110.0| 29.74| 9.53| 14.88| 36.99|19.1|
| Destiny| Carnival| 17| 101.353| 26.42| 8.92| 13.21| 38.36|10.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
only showing top 5 rows
(1) Convert the string category column to an integer index
from pyspark.ml.feature import StringIndexer
indexer = StringIndexer(inputCol="Cruise_line", outputCol="cruise_cat")
indexed = indexer.fit(df).transform(df)
indexed.show(5)
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
| Ship_name|Cruise_line|Age| Tonnage|passengers|length|cabins|passenger_density|crew|cruise_cat|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
| Journey| Azamara| 6|30.276999999999997| 6.94| 5.94| 3.55| 42.64|3.55| 16.0|
| Quest| Azamara| 6|30.276999999999997| 6.94| 5.94| 3.55| 42.64|3.55| 16.0|
|Celebration| Carnival| 26| 47.262| 14.86| 7.22| 7.43| 31.8| 6.7| 1.0|
| Conquest| Carnival| 11| 110.0| 29.74| 9.53| 14.88| 36.99|19.1| 1.0|
| Destiny| Carnival| 17| 101.353| 26.42| 8.92| 13.21| 38.36|10.0| 1.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
only showing top 5 rows
(2) Assemble the columns into a feature vector for modeling
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
assembler = VectorAssembler(
    inputCols=['Age',
               'Tonnage',
               'passengers',
               'length',
               'cabins',
               'passenger_density',
               'cruise_cat'],
    outputCol="features")
output = assembler.transform(indexed)
output.select("features", "crew").show(5)
+--------------------+----+
| features|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
+--------------------+----+
only showing top 5 rows
(3) Split the dataset into training and test sets
full_data = output.select("features", "crew")
train_data, test_data = full_data.randomSplit([0.8, 0.2])
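randomSplit is nondeterministic by default; for a reproducible split, pass a seed (the 42 here is just an example value):
train_data, test_data = full_data.randomSplit([0.8, 0.2], seed=42)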
(4) Choose and train a linear regression model
from pyspark.ml.regression import LinearRegression
lr = LinearRegression(featuresCol='features', labelCol='crew', predictionCol='prediction')
lrModel = lr.fit(train_data)
print(lrModel.coefficients)
print(lrModel.intercept)
[-0.017085691500866265,0.0064925570120491225,-0.14616134750393708,0.4009769028859461,0.8720907710851697,0.00012638567124781204,0.04043474402085859]
-0.8703567887087273
trainingSummary = lrModel.summary
print(trainingSummary.rootMeanSquaredError)
print(trainingSummary.r2)
0.9796405605574622
0.9151724396508625
trainingSummary.residuals.show(5)
+--------------------+
| residuals|
+--------------------+
| -1.3197210112896958|
| 0.2957452235313216|
| 0.648959073658145|
|0.059448228597265285|
| -0.7894144891131782|
+--------------------+
only showing top 5 rows
trainingSummary.predictions.show(5)
+--------------------+-----+------------------+
| features| crew| prediction|
+--------------------+-----+------------------+
|[5.0,86.0,21.04,9...| 8.0| 9.319721011289696|
|[5.0,115.0,35.74,...| 12.2|11.904254776468678|
|[5.0,122.0,28.5,1...| 6.7| 6.051040926341855|
|[5.0,133.5,39.59,...|13.13|13.070551771402735|
|[6.0,30.276999999...| 3.55| 4.339414489113178|
+--------------------+-----+------------------+
only showing top 5 rows
(5) Evaluate the model
test_results = lrModel.evaluate(test_data)
print(test_results.rootMeanSquaredError)
print(test_results.meanSquaredError)
print(test_results.r2)
0.7479941744294354
0.5594952849803726
0.9628948277710834
test_results.predictions.show(5)
+--------------------+-----+------------------+
| features| crew| prediction|
+--------------------+-----+------------------+
|[4.0,220.0,54.0,1...| 21.0|20.888096986384113|
|[5.0,160.0,36.34,...| 13.6|15.081837739236024|
|[7.0,116.0,31.0,9...| 12.0| 12.70952070366211|
|[8.0,77.499,19.5,...| 9.0| 8.667117440235574|
|[9.0,113.0,26.74,...|12.38|11.360531000252465|
+--------------------+-----+------------------+
only showing top 5 rows
(6) Make predictions with the model
predictions = lrModel.transform(test_data.select('features'))
predictions.show(5)
+--------------------+------------------+
| features| prediction|
+--------------------+------------------+
|[4.0,220.0,54.0,1...|20.888096986384113|
|[5.0,160.0,36.34,...|15.081837739236024|
|[7.0,116.0,31.0,9...| 12.70952070366211|
|[8.0,77.499,19.5,...| 8.667117440235574|
|[9.0,113.0,26.74,...|11.360531000252465|
+--------------------+------------------+
only showing top 5 rows
(7) Correlation between features and the label
from pyspark.sql.functions import corr
df.select(corr('crew', 'passengers')).show()
+----------------------+
|corr(crew, passengers)|
+----------------------+
| 0.9152341306065384|
+----------------------+
df.select(corr('crew', 'cabins')).show()
+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+
2. Classification
data = spark.read.csv('customer_churn.csv', inferSchema=True, header=True)
data.printSchema()
root
|-- Names: string (nullable = true)
|-- Age: double (nullable = true)
|-- Total_Purchase: double (nullable = true)
|-- Account_Manager: integer (nullable = true)
|-- Years: double (nullable = true)
|-- Num_Sites: double (nullable = true)
|-- Onboard_date: timestamp (nullable = true)
|-- Location: string (nullable = true)
|-- Company: string (nullable = true)
|-- Churn: integer (nullable = true)
data.select('Names', 'Age', 'Total_Purchase', 'Years', 'Num_Sites', 'Location', 'Company', 'Churn').show(5)
+----------------+----+--------------+-----+---------+--------------------+--------------------+-----+
| Names| Age|Total_Purchase|Years|Num_Sites| Location| Company|Churn|
+----------------+----+--------------+-----+---------+--------------------+--------------------+-----+
|Cameron Williams|42.0| 11066.8| 7.22| 8.0|10265 Elizabeth M...| Harvey LLC| 1|
| Kevin Mueller|41.0| 11916.22| 6.5| 11.0|6157 Frank Garden...| Wilson PLC| 1|
| Eric Lozano|38.0| 12884.75| 6.67| 12.0|1331 Keith Court ...|Miller, Johnson a...| 1|
| Phillip White|42.0| 8010.76| 6.71| 10.0|13120 Daniel Moun...| Smith Inc| 1|
| Cynthia Norton|37.0| 9191.58| 5.56| 9.0|765 Tricia Row Ka...| Love-Jones| 1|
+----------------+----+--------------+-----+---------+--------------------+--------------------+-----+
only showing top 5 rows
(1) From continuous features to categorical features
from pyspark.ml.feature import Binarizer, Bucketizer
# Binarizer: Total_Purchase > 5 maps to 1.0, otherwise 0.0
binarizer = Binarizer(threshold=5, inputCol='Total_Purchase', outputCol='Total_Purchase_cat')
# Bucketizer: bin Age into [0,10), [10,30), [30,50), [50,70]
bucketizer = Bucketizer(splits=[0, 10, 30, 50, 70], inputCol='Age', outputCol='age_cat')
from pyspark.ml import Pipeline
stages = [binarizer, bucketizer]
pipeline = Pipeline(stages=stages)
result = pipeline.fit(data).transform(data)
result.select('Names', 'Age', 'Total_Purchase', 'Years', 'Num_Sites', 'Company', 'Churn', 'Total_Purchase_cat', 'age_cat').show(5)
+----------------+----+--------------+-----+---------+--------------------+-----+------------------+-------+
| Names| Age|Total_Purchase|Years|Num_Sites| Company|Churn|Total_Purchase_cat|age_cat|
+----------------+----+--------------+-----+---------+--------------------+-----+------------------+-------+
|Cameron Williams|42.0| 11066.8| 7.22| 8.0| Harvey LLC| 1| 1.0| 2.0|
| Kevin Mueller|41.0| 11916.22| 6.5| 11.0| Wilson PLC| 1| 1.0| 2.0|
| Eric Lozano|38.0| 12884.75| 6.67| 12.0|Miller, Johnson a...| 1| 1.0| 2.0|
| Phillip White|42.0| 8010.76| 6.71| 10.0| Smith Inc| 1| 1.0| 2.0|
| Cynthia Norton|37.0| 9191.58| 5.56| 9.0| Love-Jones| 1| 1.0| 2.0|
+----------------+----+--------------+-----+---------+--------------------+-----+------------------+-------+
only showing top 5 rows
(2) Select columns as model input features
from pyspark.ml.feature import VectorAssembler
assembler = VectorAssembler(inputCols=['Age',
                                       'Total_Purchase',
                                       'Account_Manager',
                                       'Years',
                                       'Num_Sites'], outputCol='features')
output = assembler.transform(data)
(3) Split into training and test sets
final_data = output.select('features', 'churn')
final_data.show(5)
+--------------------+-----+
| features|churn|
+--------------------+-----+
|[42.0,11066.8,0.0...| 1|
|[41.0,11916.22,0....| 1|
|[38.0,12884.75,0....| 1|
|[42.0,8010.76,0.0...| 1|
|[37.0,9191.58,0.0...| 1|
+--------------------+-----+
only showing top 5 rows
train_churn, test_churn = final_data.randomSplit([0.8, 0.2])
(4) Choose and train a model
Method 1: logistic regression
from pyspark.ml.classification import LogisticRegression
lr_churn = LogisticRegression(featuresCol='features', labelCol='churn')
model = lr_churn.fit(train_churn)
training_sum = model.summary
training_sum.predictions.show(5)
+--------------------+-----+--------------------+--------------------+----------+
| features|churn| rawPrediction| probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[22.0,11254.38,1....| 0.0|[4.27126932828609...|[0.98622826225292...| 0.0|
|[25.0,9672.03,0.0...| 0.0|[4.32716495449521...|[0.98696716661289...| 0.0|
|[26.0,8787.39,1.0...| 1.0|[0.50691964789309...|[0.62408409087940...| 0.0|
|[26.0,8939.61,0.0...| 0.0|[5.94537216632240...|[0.99738890778859...| 0.0|
|[27.0,8628.8,1.0,...| 0.0|[5.07194344079783...|[0.99376884822224...| 0.0|
+--------------------+-----+--------------------+--------------------+----------+
only showing top 5 rows
(5) Model evaluation
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
pred_and_labels = model.evaluate(test_churn)
pred_and_labels.predictions.show(5)
+--------------------+-----+--------------------+--------------------+----------+
| features|churn| rawPrediction| probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[28.0,8670.98,0.0...| 0|[7.28032028440204...|[0.99931150948212...| 0.0|
|[29.0,5900.78,1.0...| 0|[3.80245014943030...|[0.97817110695536...| 0.0|
|[29.0,9378.24,0.0...| 0|[4.42704540541525...|[0.98819136594128...| 0.0|
|[30.0,8874.83,0.0...| 0|[2.92386878493753...|[0.94901382190383...| 0.0|
|[30.0,10744.14,1....| 1|[1.56910959319232...|[0.82765663668636...| 0.0|
+--------------------+-----+--------------------+--------------------+----------+
only showing top 5 rows
churn_eval = BinaryClassificationEvaluator(rawPredictionCol='rawPrediction', labelCol='churn')
churn_eval_multi = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='churn', metricName='accuracy')
accuracy = churn_eval_multi.evaluate(pred_and_labels.predictions)
accuracy
0.918918918918919
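churn_eval above scores area under the ROC curve from the raw prediction column; a minimal usage sketch on the same test predictions:
auc = churn_eval.evaluate(pred_and_labels.predictions)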
(6) Model prediction
churn_test = model.transform(test_churn.select('features'))
churn_test.show(5)
+--------------------+--------------------+--------------------+----------+
| features| rawPrediction| probability|prediction|
+--------------------+--------------------+--------------------+----------+
|[28.0,8670.98,0.0...|[7.28032028440204...|[0.99931150948212...| 0.0|
|[29.0,5900.78,1.0...|[3.80245014943030...|[0.97817110695536...| 0.0|
|[29.0,9378.24,0.0...|[4.42704540541525...|[0.98819136594128...| 0.0|
|[30.0,8874.83,0.0...|[2.92386878493753...|[0.94901382190383...| 0.0|
|[30.0,10744.14,1....|[1.56910959319232...|[0.82765663668636...| 0.0|
+--------------------+--------------------+--------------------+----------+
only showing top 5 rows
Method 2: decision tree
from pyspark.ml.classification import RandomForestClassifier, DecisionTreeClassifier
dtc = DecisionTreeClassifier(labelCol='churn', featuresCol='features')
dtc_model = dtc.fit(train_churn)
print(dtc_model.featureImportances)
(5,[0,1,3,4],[0.09646621280342624,0.09365722250595962,0.14583722780378533,0.6640393368868287])
predictions = dtc_model.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.9081081081081082
Method 3: random forest
rfc = RandomForestClassifier(labelCol="churn", featuresCol="features", numTrees=20)
rfc_model = rfc.fit(train_churn)
print(rfc_model.featureImportances)
(5,[0,1,2,3,4],[0.09112246833941745,0.07289697555412486,0.00807666535024141,0.1859341040975848,0.6419697866586315])
predictions = rfc_model.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.9081081081081082
Method 4: gradient-boosted trees
from pyspark.ml.classification import GBTClassifier
gbt = GBTClassifier(labelCol="churn", featuresCol="features", maxIter=20)
gbt_model = gbt.fit(train_churn)
predictions = gbt_model.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.9027027027027027
Method 5: logistic regression with cross-validation
from pyspark.ml.classification import LogisticRegression
blor = LogisticRegression(featuresCol='features', labelCol='churn', family='binomial')
from pyspark.ml.tuning import ParamGridBuilder
param_grid = ParamGridBuilder() \
    .addGrid(blor.regParam, [0, 0.5, 1, 2]) \
    .addGrid(blor.elasticNetParam, [0, 0.5, 1]) \
    .build()
from pyspark.ml.evaluation import BinaryClassificationEvaluator
evaluator = BinaryClassificationEvaluator()  # an AUC-based option; the cross-validator below selects by accuracy instead
from pyspark.ml.tuning import CrossValidator
cv = CrossValidator(estimator=blor, estimatorParamMaps=param_grid, evaluator=churn_eval_multi, numFolds=4)
cvModel = cv.fit(train_churn)
cvModel.bestModel.intercept
-18.032119312923868
cvModel.bestModel.coefficients
DenseVector([0.0523, 0.0, 0.4303, 0.5307, 1.1368])
cvModel.bestModel._java_obj.getRegParam()  # via the private _java_obj; newer PySpark exposes getRegParam() directly
0.0
cvModel.bestModel._java_obj.getElasticNetParam()
0.0
predictions = cvModel.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.918918918918919
Extra: the confusion matrix
label_pred = predictions.select('churn', 'prediction')
label_pred.rdd.zipWithIndex().countByKey()  # counts occurrences of each (churn, prediction) Row
defaultdict(int,
{Row(churn=0, prediction=0.0): 151,
Row(churn=1, prediction=0.0): 9,
Row(churn=1, prediction=1.0): 19,
Row(churn=0, prediction=1.0): 6})
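The same tabulation reads more directly in the DataFrame API; a sketch:
label_pred.groupBy('churn', 'prediction').count().show()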
3. Clustering
data = spark.read.csv("hack_data.csv", header=True, inferSchema=True)
data.printSchema()
root
|-- Session_Connection_Time: double (nullable = true)
|-- Bytes Transferred: double (nullable = true)
|-- Kali_Trace_Used: integer (nullable = true)
|-- Servers_Corrupted: double (nullable = true)
|-- Pages_Corrupted: double (nullable = true)
|-- Location: string (nullable = true)
|-- WPM_Typing_Speed: double (nullable = true)
data.select('Session_Connection_Time', 'Bytes Transferred', 'Kali_Trace_Used', 'Servers_Corrupted', 'Pages_Corrupted', 'WPM_Typing_Speed').show(5)
+-----------------------+-----------------+---------------+-----------------+---------------+----------------+
|Session_Connection_Time|Bytes Transferred|Kali_Trace_Used|Servers_Corrupted|Pages_Corrupted|WPM_Typing_Speed|
+-----------------------+-----------------+---------------+-----------------+---------------+----------------+
| 8.0| 391.09| 1| 2.96| 7.0| 72.37|
| 20.0| 720.99| 0| 3.04| 9.0| 69.08|
| 31.0| 356.32| 1| 3.71| 8.0| 70.58|
| 2.0| 228.08| 1| 2.48| 8.0| 70.8|
| 20.0| 408.5| 0| 3.57| 8.0| 71.28|
+-----------------------+-----------------+---------------+-----------------+---------------+----------------+
only showing top 5 rows
data.columns
['Session_Connection_Time',
'Bytes Transferred',
'Kali_Trace_Used',
'Servers_Corrupted',
'Pages_Corrupted',
'Location',
'WPM_Typing_Speed']
(1) Select columns as model input features
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
feat_cols = ['Session_Connection_Time', 'Bytes Transferred', 'Kali_Trace_Used',
             'Servers_Corrupted', 'Pages_Corrupted', 'WPM_Typing_Speed']
vec_assembler = VectorAssembler(inputCols=feat_cols, outputCol='features')
final_data = vec_assembler.transform(data)
final_data.select('features').head(1)[0]
Row(features=DenseVector([8.0, 391.09, 1.0, 2.96, 7.0, 72.37]))
(2) Standardize the features
from pyspark.ml.feature import StandardScaler
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=False)
cluster_final_data = scaler.fit(final_data).transform(final_data)
cluster_final_data.select("scaledFeatures").show(5)
+--------------------+
| scaledFeatures|
+--------------------+
|[0.56785108466505...|
|[1.41962771166263...|
|[2.20042295307707...|
|[0.14196277116626...|
|[1.41962771166263...|
+--------------------+
only showing top 5 rows
cluster_final_data.select("scaledFeatures").head(1)[0]
Row(scaledFeatures=DenseVector([0.5679, 1.3658, 1.9976, 1.2859, 2.2849, 5.3963]))
(3) K-Means clustering
from pyspark.ml.clustering import KMeans
kmeans = KMeans(featuresCol='scaledFeatures', k=3)
model = kmeans.fit(cluster_final_data)
model.computeCost(cluster_final_data)  # WSSSE; removed in Spark 3.0, where model.summary.trainingCost replaces it
434.1492898715845
model.clusterCenters()
[array([1.30217042, 1.25830099, 0. , 1.35793211, 2.57251009,
5.24230473]),
array([2.99991988, 2.92319035, 1.05261534, 3.20390443, 4.51321315,
3.28474 ]),
array([1.21780112, 1.37901802, 1.99757683, 1.37198977, 2.55237797,
5.29152222])]
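A common heuristic for choosing k is to compare the WSSSE over a range of k values (a sketch reusing the scaled features; the exact costs will vary with the data and seed):
for k in range(2, 7):
    model_k = KMeans(featuresCol='scaledFeatures', k=k).fit(cluster_final_data)
    print('k={}: WSSSE={}'.format(k, model_k.computeCost(cluster_final_data)))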
(4) Model prediction
model.transform(cluster_final_data).groupBy('prediction').count().show()
+----------+-----+
|prediction|count|
+----------+-----+
| 1| 167|
| 2| 83|
| 0| 84|
+----------+-----+
model.transform(cluster_final_data).show(5)
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+--------------------+--------------------+----------+
|Session_Connection_Time|Bytes Transferred|Kali_Trace_Used|Servers_Corrupted|Pages_Corrupted| Location|WPM_Typing_Speed| features| scaledFeatures|prediction|
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+--------------------+--------------------+----------+
| 8.0| 391.09| 1| 2.96| 7.0| Slovenia| 72.37|[8.0,391.09,1.0,2...|[0.56785108466505...| 2|
| 20.0| 720.99| 0| 3.04| 9.0|British Virgin Is...| 69.08|[20.0,720.99,0.0,...|[1.41962771166263...| 0|
| 31.0| 356.32| 1| 3.71| 8.0| Tokelau| 70.58|[31.0,356.32,1.0,...|[2.20042295307707...| 2|
| 2.0| 228.08| 1| 2.48| 8.0| Bolivia| 70.8|[2.0,228.08,1.0,2...|[0.14196277116626...| 2|
| 20.0| 408.5| 0| 3.57| 8.0| Iraq| 71.28|[20.0,408.5,0.0,3...|[1.41962771166263...| 0|
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+--------------------+--------------------+----------+
only showing top 5 rows
4. Text mining with TF-IDF
data = spark.read.csv("SMSSpamCollection", inferSchema=True, sep='\t')
data = data.withColumnRenamed('_c0', 'class').withColumnRenamed('_c1', 'text')
data.show(5)
+-----+--------------------+
|class| text|
+-----+--------------------+
| ham|Go until jurong p...|
| ham|Ok lar... Joking ...|
| spam|Free entry in 2 a...|
| ham|U dun say so earl...|
| ham|Nah I don't think...|
+-----+--------------------+
only showing top 5 rows
(1) Data preprocessing
from pyspark.sql.functions import length
data = data.withColumn('length', length(data['text']))
(1.1) Tokenization
from pyspark.ml.feature import Tokenizer, StopWordsRemover, CountVectorizer, IDF, StringIndexer
tokenizer = Tokenizer(inputCol="text", outputCol="token_text")
(1.2) Remove stop words
stop_remove = StopWordsRemover(inputCol='token_text', outputCol='stop_tokens')
(1.3) Compute term frequencies
count_vec = CountVectorizer(inputCol='stop_tokens', outputCol='c_vec')
(1.4) Compute the inverse document frequency
idf = IDF(inputCol="c_vec", outputCol="tf_idf")
(1.5) Map the class labels from strings to indices
ham_spam_to_num = StringIndexer(inputCol='class', outputCol='label')
(2) Assemble the columns into model input features
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vector
clean_up = VectorAssembler(inputCols=['tf_idf', 'length'], outputCol='features')
(3) Build the model
from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()  # defaults to labelCol='label', featuresCol='features'
(4) Build the pipeline
from pyspark.ml import Pipeline
data_prep_pipe = Pipeline(stages=[ham_spam_to_num, tokenizer, stop_remove, count_vec, idf, clean_up])
cleaner = data_prep_pipe.fit(data)
clean_data = cleaner.transform(data)
(5) Split into training and test sets
full_data = clean_data.select(['label', 'features'])
(train_data, test_data) = full_data.randomSplit([0.8, 0.2])
(6) Train the model
model = nb.fit(train_data)
test_results = model.transform(test_data)
test_results.show(5)
+-----+--------------------+--------------------+--------------------+----------+
|label| features| rawPrediction| probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
| 0.0|(13588,[0,1,2,3,4...|[-1350.8171609962...|[1.0,1.0733004084...| 0.0|
| 0.0|(13588,[0,1,2,3,4...|[-3071.3460250107...|[1.0,1.4929420982...| 0.0|
| 0.0|(13588,[0,1,2,3,4...|[-1454.8011163433...|[1.0,1.8318571536...| 0.0|
| 0.0|(13588,[0,1,2,3,4...|[-1169.5412775216...|[1.0,5.0678369468...| 0.0|
| 0.0|(13588,[0,1,2,3,5...|[-1769.7764271667...|[1.0,2.5959352248...| 0.0|
+-----+--------------------+--------------------+--------------------+----------+
only showing top 5 rows
(7) Evaluate the model
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
acc_eval = MulticlassClassificationEvaluator()  # note: the default metric is F1, not accuracy
f1 = acc_eval.evaluate(test_results)
print("F1 score of model at predicting spam was: {}".format(f1))
F1 score of model at predicting spam was: 0.9416633505993651
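To report plain accuracy instead, pass metricName explicitly; a minimal sketch:
acc = MulticlassClassificationEvaluator(metricName='accuracy').evaluate(test_results)
print("Accuracy of model at predicting spam was: {}".format(acc))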
spark.stop()