# Reference: https://github.com/jadianes/spark-py-notebooks
# Spark SQL and Data Frames
#!/usr/bin/python # -*- coding: UTF-8 -*- import urllib from pyspark import SparkContext,SparkConf f = urllib.urlretrieve ("http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data_10_percent.gz", "kddcup.data_10_percent.gz") data_file = "./kddcup.data_10_percent.gz" sc = SparkContext(conf=SparkConf().setAppName("The first example")) # Creating a RDD from a file raw_data = sc.textFile(data_file) print "Train data size is {}".format(raw_data.count()) # Getting a Data Frame from pyspark.sql import SQLContext sqlContext = SQLContext(sc) # Inferring the schema from pyspark.sql import Row csv_data = raw_data.map(lambda l: l.split(",")) row_data = csv_data.map(lambda p: Row( duration=int(p[0]), protocol_type=p[1], service=p[2], flag=p[3], src_bytes=int(p[4]), dst_bytes=int(p[5]) ) ) # RDD interactions_df = sqlContext.createDataFrame(row_data) # DF RDD-->DF interactions_df.registerTempTable("interactions") # Select tcp network interactions with more than 1 second duration and no transfer from destination tcp_interactions = sqlContext.sql(""" SELECT duration, dst_bytes FROM interactions WHERE protocol_type = 'tcp' AND duration > 1000 AND dst_bytes = 0 """) tcp_interactions.show() # Output duration together with dst_bytes tcp_interactions_out = tcp_interactions.map(lambda p: "Duration: {}, Dest. 
bytes: {}".format(p.duration, p.dst_bytes)) for ti_out in tcp_interactions_out.collect(): print ti_out interactions_df.printSchema() # Queries as DataFrame operations from time import time t0 = time() interactions_df.select("protocol_type", "duration", "dst_bytes").groupBy("protocol_type").count().show() tt = time() - t0 print "Query performed in {} seconds".format(round(tt,3)) t0 = time() interactions_df.select("protocol_type", "duration", "dst_bytes").filter(interactions_df.duration>1000).filter(interactions_df.dst_bytes==0).groupBy("protocol_type").count().show() tt = time() - t0 print "Query performed in {} seconds".format(round(tt,3)) def get_label_type(label): if label != "normal.": return "attack" else: return "normal" row_labeled_data = csv_data.map(lambda p: Row( duration=int(p[0]), protocol_type=p[1], service=p[2], flag=p[3], src_bytes=int(p[4]), dst_bytes=int(p[5]), label=get_label_type(p[41]) ) ) interactions_labeled_df = sqlContext.createDataFrame(row_labeled_data) t0 = time() interactions_labeled_df.select("label").groupBy("label").count().show() tt = time() - t0 print "Query performed in {} seconds".format(round(tt,3)) t0 = time() interactions_labeled_df.select("label", "protocol_type").groupBy("label", "protocol_type").count().show() tt = time() - t0 print "Query performed in {} seconds".format(round(tt,3)) t0 = time() interactions_labeled_df.select("label", "protocol_type", "dst_bytes").groupBy("label", "protocol_type", interactions_labeled_df.dst_bytes==0).count().show() tt = time() - t0 print "Query performed in {} seconds".format(round(tt,3))