Basic SparkSQL Operations in PySpark

SparkSQL

Advantages of DataFrame:

  • The introduction of DataFrame gave Spark the ability to process large-scale structured data; it is not only simpler to use than the original RDD-based transformations, but also delivers better computational performance
  • Spark can easily convert data from sources such as MySQL into a DataFrame and supports SQL queries on it

Differences between DataFrame and RDD:
(Figure: DataFrame vs. RDD — ./imgs/dataframe.png)

  • An RDD is a distributed collection of Java objects, but the internal structure of those objects is opaque to the RDD
  • A DataFrame is a distributed dataset built on top of RDDs that carries detailed schema (structure) information
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql import functions as fn
import os


# Create a SparkSession, the entry point for loading, transforming and processing data
spark = SparkSession.builder.appName("test").getOrCreate()
sc = spark.sparkContext
ROOT_PATH = "file://" + os.getcwd()

Converting between RDD and DataFrame

# Create a DataFrame
l = [('Ankit',25),('Jalfaizy',22),('saurabh',20),('Bala',26)]
rdd = sc.parallelize(l)

# Wrap each tuple in a Row, i.e. attach column names to the data
people = rdd.map(lambda x: Row(name=x[0], age=int(x[1])))
people.collect()
[Row(name='Ankit', age=25),
 Row(name='Jalfaizy', age=22),
 Row(name='saurabh', age=20),
 Row(name='Bala', age=26)]
# Rows are objects, so fields are accessed as attributes
row = Row(name="Alice", age=11)
row.name
'Alice'
schemapeople = spark.createDataFrame(people)
schemapeople.show()
+--------+---+
|    name|age|
+--------+---+
|   Ankit| 25|
|Jalfaizy| 22|
| saurabh| 20|
|    Bala| 26|
+--------+---+
# Register a temporary view so the data can be queried with SQL
schemapeople.createOrReplaceTempView("person")
spark.sql("select * from person")
DataFrame[name: string, age: bigint]
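spark.sql only builds a DataFrame; the query runs when an action is called. A minimal sketch against the same person view (the filter value is arbitrary):

# trigger execution and display the result
spark.sql("select name, age from person where age > 21").show()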
schemapeople.printSchema()
root
 |-- name: string (nullable = true)
 |-- age: long (nullable = true)
# Convert the DataFrame back to an RDD
schemapeople.rdd.collect()
[Row(name='Ankit', age=25),
 Row(name='Jalfaizy', age=22),
 Row(name='saurabh', age=20),
 Row(name='Bala', age=26)]
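As an alternative to building Row objects, an RDD of tuples can be converted directly with toDF by passing column names; a minimal sketch using the rdd created above (df_from_rdd is an illustrative name):

# convert the tuple RDD directly, supplying column names
df_from_rdd = rdd.toDF(["name", "age"])
df_from_rdd.printSchema()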

Basic SparkSQL operations

Working with CSV files

# Read a CSV file
# Option 1:
# df = spark.read.format("csv").option("header","true").load("./data/iris.csv")
# Option 2:
df = spark.read.csv(ROOT_PATH + "/data/iris.csv", header=True)
df.show(5)
+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| setosa|
|         4.9|        3.0|         1.4|        0.2| setosa|
|         4.7|        3.2|         1.3|        0.2| setosa|
|         4.6|        3.1|         1.5|        0.2| setosa|
|         5.0|        3.6|         1.4|        0.2| setosa|
+------------+-----------+------------+-----------+-------+
only showing top 5 rows
# Print the schema
df.printSchema()
root
 |-- sepal_length: string (nullable = true)
 |-- sepal_width: string (nullable = true)
 |-- petal_length: string (nullable = true)
 |-- petal_width: string (nullable = true)
 |-- species: string (nullable = true)
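Note that with header=True alone every column is read as a string, as the schema above shows. Passing inferSchema=True makes Spark scan the file once more and detect numeric types; a sketch assuming the same file (df_typed is an illustrative name):

# re-read with schema inference so the measurement columns become doubles
df_typed = spark.read.csv(ROOT_PATH + "/data/iris.csv", header=True, inferSchema=True)
df_typed.printSchema()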
# Column names
df.columns
['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
# Total number of rows
df.count()
150
# Drop columns; drop takes a variable number of column names
df.drop(*["species", "sepal_width"]).show(5)
+------------+------------+-----------+
|sepal_length|petal_length|petal_width|
+------------+------------+-----------+
|         5.1|         1.4|        0.2|
|         4.9|         1.4|        0.2|
|         4.7|         1.3|        0.2|
|         4.6|         1.5|        0.2|
|         5.0|         1.4|        0.2|
+------------+------------+-----------+
only showing top 5 rows
# Summary statistics
df.describe().show()
+-------+------------------+-------------------+------------------+------------------+---------+
|summary|      sepal_length|        sepal_width|      petal_length|       petal_width|  species|
+-------+------------------+-------------------+------------------+------------------+---------+
|  count|               150|                150|               150|               150|      150|
|   mean| 5.843333333333335|  3.057333333333334|3.7580000000000027| 1.199333333333334|     null|
| stddev|0.8280661279778637|0.43586628493669793|1.7652982332594662|0.7622376689603467|     null|
|    min|               4.3|                2.0|               1.0|               0.1|   setosa|
|    max|               7.9|                4.4|               6.9|               2.5|virginica|
+-------+------------------+-------------------+------------------+------------------+---------+
# df.describe(*cols)
df.describe("petal_length").show()
+-------+------------------+
|summary|      petal_length|
+-------+------------------+
|  count|               150|
|   mean|3.7580000000000027|
| stddev|1.7652982332594662|
|    min|               1.0|
|    max|               6.9|
+-------+------------------+
# Select columns; collect() returns the data as a list of Row objects
df.select("*").collect()[:3]
[Row(sepal_length='5.1', sepal_width='3.5', petal_length='1.4', petal_width='0.2', species='setosa'),
 Row(sepal_length='4.9', sepal_width='3.0', petal_length='1.4', petal_width='0.2', species='setosa'),
 Row(sepal_length='4.7', sepal_width='3.2', petal_length='1.3', petal_width='0.2', species='setosa')]
# Select a few columns
df.select(*["sepal_width", "sepal_length"]).show(5)
+-----------+------------+
|sepal_width|sepal_length|
+-----------+------------+
|        3.5|         5.1|
|        3.0|         4.9|
|        3.2|         4.7|
|        3.1|         4.6|
|        3.6|         5.0|
+-----------+------------+
only showing top 5 rows
# Distinct values of the species column
df.select(["species"]).distinct().show()
+----------+
|   species|
+----------+
| virginica|
|versicolor|
|    setosa|
+----------+
# Common aggregate functions:
# avg(), count(), countDistinct(), first(), kurtosis(),
# max(), mean(), min(), skewness(), stddev(), stddev_pop(),
# stddev_samp(), sum(), sumDistinct(), var_pop(), var_samp() and variance()

# Note: with the dict form of agg, each column can only be aggregated once
# (a later entry for the same column overwrites the earlier one)
df.groupBy("species").agg({"sepal_length":"mean", "sepal_width":"max" , "sepal_length":"variance"}).show()
+----------+----------------+----------------------+
|   species|max(sepal_width)|variance(sepal_length)|
+----------+----------------+----------------------+
| virginica|             3.8|    0.4043428571428571|
|versicolor|             3.4|    0.2664326530612246|
|    setosa|             4.4|   0.12424897959183674|
+----------+----------------+----------------------+
# With pyspark.sql.functions (fn), the same column can be aggregated several times, each result under its own name
df.agg(fn.count("sepal_length").alias("sepal_length_count"), fn.mean("sepal_width"), fn.max("sepal_width"), fn.countDistinct("species")).show()
+------------------+------------------+----------------+--------------+
|sepal_length_count|  avg(sepal_width)|max(sepal_width)|count(species)|
+------------------+------------------+----------------+--------------+
|               150|3.0573333333333337|             4.4|             3|
+------------------+------------------+----------------+--------------+
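The functions-based style also combines with groupBy, so one column can be aggregated several ways per group (something the dict form cannot express); a minimal sketch:

# several aggregations per group, including two on sepal_length
df.groupBy("species").agg(
    fn.mean("sepal_length").alias("sepal_length_mean"),
    fn.variance("sepal_length").alias("sepal_length_var"),
    fn.max("sepal_width").alias("sepal_width_max")
).show()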
# Split the dataset
train, val = df.randomSplit([0.7, 0.3], seed=0)
# Roughly 70% goes to the training set (randomSplit samples rows, so the ratio is approximate)
train.count()/df.count()
0.7266666666666667
# Check which classes ended up in the training set
train.select("species").distinct().show()
+----------+
|   species|
+----------+
| virginica|
|versicolor|
|    setosa|
+----------+
# The result can be converted to pandas for further processing
val.select("species").distinct().toPandas()
      species
0   virginica
1  versicolor
2      setosa
# Set difference: species present in val but not in train
val.select("species").subtract(train.select("species")).show()
+-------+
|species|
+-------+
+-------+
# Sample the dataset (without replacement; fraction is approximate)
sample_data = df.sample(False, fraction = 0.1, seed = 3)
sample_data.count()
13
# Cross-tabulation: co-occurrence counts of species and sepal_width
df.crosstab("species", "sepal_width").toPandas()
(crosstab output: a 3 × 24 pandas table, one row per species and one column per distinct sepal_width value, each cell holding the co-occurrence count)

# The same check from the other direction: species present in train but missing from val
train.select('species').subtract(val.select('species')).distinct().count()
0
# Simple SQL queries via a temporary view
train.createOrReplaceTempView("train_table")
# Returns a DataFrame
spark.sql("select petal_length from train_table where petal_length > 1.5").show(5)
+------------+
|petal_length|
+------------+
|         1.6|
|         1.6|
|         1.6|
|         1.9|
|         3.3|
+------------+
only showing top 5 rows
# The full path, prefixed with file://, is required
# Option 1:
# train.write.json(ROOT_PATH + "/data/trainDF.json")
# Option 2:
train.write.format("json").save(ROOT_PATH + "/data/trainDF.json")
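Running the save a second time fails because the output directory already exists; the writer's mode("overwrite") setting replaces it instead. A small sketch:

# overwrite the target directory if it already exists
train.write.mode("overwrite").json(ROOT_PATH + "/data/trainDF.json")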

Working with JSON files

When reading JSON, data types are detected automatically

jsonString = [
"""{ "id" : "01001", "city" : "AGAWAM",  "pop" : 15338, "state" : "MA" }""",
"""{ "id" : "01002", "city" : "CUSHMAN", "pop" : 36963, "state" : "MA" }"""
]
jsonString
['{ "id" : "01001", "city" : "AGAWAM",  "pop" : 15338, "state" : "MA" }',
 '{ "id" : "01002", "city" : "CUSHMAN", "pop" : 36963, "state" : "MA" }']
jsonRDD = sc.parallelize(jsonString)
jsonRDD.collect()
['{ "id" : "01001", "city" : "AGAWAM",  "pop" : 15338, "state" : "MA" }',
 '{ "id" : "01002", "city" : "CUSHMAN", "pop" : 36963, "state" : "MA" }']
jsonDF = spark.read.json(jsonRDD)
jsonDF.show()
+-------+-----+-----+-----+
|   city|   id|  pop|state|
+-------+-----+-----+-----+
| AGAWAM|01001|15338|   MA|
|CUSHMAN|01002|36963|   MA|
+-------+-----+-----+-----+
jsonDF.printSchema()  # note the contrast with CSV: here the column types are inferred automatically
root
 |-- city: string (nullable = true)
 |-- id: string (nullable = true)
 |-- pop: long (nullable = true)
 |-- state: string (nullable = true)
jsonDF = spark.read.json(ROOT_PATH + "/data/province.json")
jsonDF.show(5)
+----+--------+------------+
|code|    name|provinceCode|
+----+--------+------------+
|1101|  市辖区|          11|
|1201|  市辖区|          12|
|1301|石家庄市|          13|
|1302|  唐山市|          13|
|1303|秦皇岛市|          13|
+----+--------+------------+
only showing top 5 rows
jsonDF.printSchema()
root
 |-- code: string (nullable = true)
 |-- name: string (nullable = true)
 |-- provinceCode: string (nullable = true)
jsonDF.describe().show()
+-------+-----------------+--------+------------------+
|summary|             code|    name|      provinceCode|
+-------+-----------------+--------+------------------+
|  count|              342|     342|               342|
|   mean|4046.502923976608|    null| 40.35672514619883|
| stddev|1492.861959917837|    null|14.898333817383897|
|    min|             1101|七台河市|                11|
|    max|             6590|  龙岩市|                65|
+-------+-----------------+--------+------------------+
jsonDF.count()
342
jsonDF.filter(jsonDF.code>1440).show(4)
+----+----------+------------+
|code|      name|provinceCode|
+----+----------+------------+
|1501|呼和浩特市|          15|
|1502|    包头市|          15|
|1503|    乌海市|          15|
|1504|    赤峰市|          15|
+----+----------+------------+
only showing top 4 rows
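filter also accepts column expressions combined with & (and) and | (or), each condition wrapped in parentheses; a minimal sketch on the same DataFrame:

# cities in province 13 whose code is greater than 1302
jsonDF.filter((fn.col("provinceCode") == "13") & (fn.col("code") > 1302)).show(3)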
# Read the data with explicitly specified column types
jsonString = [
"""{ "id" : "01001", "city" : "AGAWAM",  "pop" : 15338, "state" : "MA" }""",
"""{ "id" : "01002", "city" : "CUSHMAN", "pop" : 36963, "state" : "MA" }"""
]

jsonRDD = sc.parallelize(jsonString)
jsonRDD.collect()
['{ "id" : "01001", "city" : "AGAWAM",  "pop" : 15338, "state" : "MA" }',
 '{ "id" : "01002", "city" : "CUSHMAN", "pop" : 36963, "state" : "MA" }']
from pyspark.sql.types import *
# Define the schema (option 1)
# StructType: the overall schema, describing the structure of the JSON object
# XxxType: the data type of an individual column
jsonSchema = StructType() \
  .add("id", StringType(),True) \
  .add("city", StringType()) \
  .add("pop" , LongType()) \
  .add("state",StringType())
reader = spark.read.schema(jsonSchema)
reader.json(jsonRDD).show()
+-----+-------+-----+-----+
|   id|   city|  pop|state|
+-----+-------+-----+-----+
|01001| AGAWAM|15338|   MA|
|01002|CUSHMAN|36963|   MA|
+-----+-------+-----+-----+
# Check the resulting types
reader.json(jsonRDD).printSchema()
root
 |-- id: string (nullable = true)
 |-- city: string (nullable = true)
 |-- pop: long (nullable = true)
 |-- state: string (nullable = true)
# Specify the input types (option 2)
scheme = StructType([StructField("id", StringType())
                     , StructField("city", StringType())
                     , StructField("pop", LongType())
                     , StructField("state", StringType())
                    ]
                   )
data1 = spark.read.json(jsonRDD, schema=scheme)
data1.show()
+-----+-------+-----+-----+
|   id|   city|  pop|state|
+-----+-------+-----+-----+
|01001| AGAWAM|15338|   MA|
|01002|CUSHMAN|36963|   MA|
+-----+-------+-----+-----+
data1.printSchema()
root
 |-- id: string (nullable = true)
 |-- city: string (nullable = true)
 |-- pop: long (nullable = true)
 |-- state: string (nullable = true)

Data cleaning

Deduplication

df = spark.createDataFrame([
    
    (1, 144.5, 5.9, 33, 'M'),
    (2, 167.2, 5.4, 45, 'M'),
    (3, 124.1, 5.2, 23, 'F'),
    (4, 144.5, 5.9, 33, 'M'),
    (5, 133.2, 5.7, 54, 'F'),
    (3, 124.1, 5.2, 23, 'F'),
    (5, 129.2, 5.3, 42, 'M'),]
    , ['id', 'weight', 'height', 'age', 'gender']
)
df.show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
|  1| 144.5|   5.9| 33|     M|
|  2| 167.2|   5.4| 45|     M|
|  3| 124.1|   5.2| 23|     F|
|  4| 144.5|   5.9| 33|     M|
|  5| 133.2|   5.7| 54|     F|
|  3| 124.1|   5.2| 23|     F|
|  5| 129.2|   5.3| 42|     M|
+---+------+------+---+------+
# Deduplicate: by default a row is dropped only when all of its columns match another row (here the duplicated id=3 row)
df.dropDuplicates().show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
|  5| 133.2|   5.7| 54|     F|
|  5| 129.2|   5.3| 42|     M|
|  1| 144.5|   5.9| 33|     M|
|  4| 144.5|   5.9| 33|     M|
|  2| 167.2|   5.4| 45|     M|
|  3| 124.1|   5.2| 23|     F|
+---+------+------+---+------+
# Deduplicate on a subset of columns, ignoring the meaningless id field (so id=1 and id=4, identical except for id, collapse into one row)
df3 = df.drop_duplicates(subset=[c for c in df.columns if c!="id"])
df3.show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
|  5| 133.2|   5.7| 54|     F|
|  1| 144.5|   5.9| 33|     M|
|  2| 167.2|   5.4| 45|     M|
|  3| 124.1|   5.2| 23|     F|
|  5| 129.2|   5.3| 42|     M|
+---+------+------+---+------+
# Aggregate specific columns; alias renames the resulting column
df3.agg(fn.count("id").alias("id_count"), fn.countDistinct("id").alias("distinct_id_count")).show()
+--------+-----------------+
|id_count|distinct_id_count|
+--------+-----------------+
|       5|                4|
+--------+-----------------+
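To see which id values are still duplicated rather than just counting them, a groupBy/count plus a filter works; a small sketch:

# ids that appear more than once after the partial deduplication
df3.groupBy("id").count().filter(fn.col("count") > 1).show()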
# Add a new id column; monotonically_increasing_id produces unique, increasing (but not consecutive) values
df3.withColumn("new_id", fn.monotonically_increasing_id()).show()
+---+------+------+---+------+-------------+
| id|weight|height|age|gender|       new_id|
+---+------+------+---+------+-------------+
|  5| 133.2|   5.7| 54|     F|  25769803776|
|  1| 144.5|   5.9| 33|     M| 171798691840|
|  2| 167.2|   5.4| 45|     M| 592705486848|
|  3| 124.1|   5.2| 23|     F|1236950581248|
|  5| 129.2|   5.3| 42|     M|1365799600128|
+---+------+------+---+------+-------------+
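The generated ids depend on the partition layout, so they are unique but not consecutive. If a consecutive 1..n index is needed, one option is a window-based row_number; a sketch (ordering by id is an arbitrary choice, and a global window pulls all rows into a single partition):

from pyspark.sql import Window

# consecutive index ordered by id
w = Window.orderBy("id")
df3.withColumn("row_idx", fn.row_number().over(w)).show()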

Handling missing values

df_miss = spark.createDataFrame([
    (1, 143.5, 5.6, 28,'M', 100000),
    (2, 167.2, 5.4, 45,'M', None),
    (3, None , 5.2, None, None, None),
    (4, 144.5, 5.9, 33, 'M', None),
    (5, 133.2, 5.7, 54, 'F', None),
    (6, 124.1, 5.2, None, 'F', None),
    (7, 129.2, 5.3, 42, 'M', 76000),]
    , ['id', 'weight', 'height', 'age', 'gender', 'income']
)
df_miss.show()
+---+------+------+----+------+------+
| id|weight|height| age|gender|income|
+---+------+------+----+------+------+
|  1| 143.5|   5.6|  28|     M|100000|
|  2| 167.2|   5.4|  45|     M|  null|
|  3|  null|   5.2|null|  null|  null|
|  4| 144.5|   5.9|  33|     M|  null|
|  5| 133.2|   5.7|  54|     F|  null|
|  6| 124.1|   5.2|null|     F|  null|
|  7| 129.2|   5.3|  42|     M| 76000|
+---+------+------+----+------+------+
# Count missing values per row: convert to an RDD and sum the booleans (True counts as 1)
df_miss.rdd.map(lambda row: (row["id"], sum(c is None for c in row))).collect()
[(1, 0), (2, 1), (3, 4), (4, 1), (5, 1), (6, 2), (7, 0)]
# Fraction of missing values per column
df_miss.agg(*[(1-(fn.count(c)/fn.count("*"))).alias(c+"_missing") for c in df_miss.columns]).show()
+----------+------------------+--------------+------------------+------------------+------------------+
|id_missing|    weight_missing|height_missing|       age_missing|    gender_missing|    income_missing|
+----------+------------------+--------------+------------------+------------------+------------------+
|       0.0|0.1428571428571429|           0.0|0.2857142857142857|0.1428571428571429|0.7142857142857143|
+----------+------------------+--------------+------------------+------------------+------------------+
# "*"表示统计所有的数据
df_miss.agg(fn.count("*")).show()
+--------+
|count(1)|
+--------+
|       7|
+--------+
# Drop the column with too many missing values (income)
df_miss1 = df_miss.select([c for c in df_miss.columns if c!="income"])
df_miss1.show()
+---+------+------+----+------+
| id|weight|height| age|gender|
+---+------+------+----+------+
|  1| 143.5|   5.6|  28|     M|
|  2| 167.2|   5.4|  45|     M|
|  3|  null|   5.2|null|  null|
|  4| 144.5|   5.9|  33|     M|
|  5| 133.2|   5.7|  54|     F|
|  6| 124.1|   5.2|null|     F|
|  7| 129.2|   5.3|  42|     M|
+---+------+------+----+------+
# dropna(thresh=n) keeps only the rows that have at least n non-null values
(df_miss1.dropna(thresh=3).show(), df_miss1.dropna(thresh=5).show())
+---+------+------+----+------+
| id|weight|height| age|gender|
+---+------+------+----+------+
|  1| 143.5|   5.6|  28|     M|
|  2| 167.2|   5.4|  45|     M|
|  4| 144.5|   5.9|  33|     M|
|  5| 133.2|   5.7|  54|     F|
|  6| 124.1|   5.2|null|     F|
|  7| 129.2|   5.3|  42|     M|
+---+------+------+----+------+

+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
|  1| 143.5|   5.6| 28|     M|
|  2| 167.2|   5.4| 45|     M|
|  4| 144.5|   5.9| 33|     M|
|  5| 133.2|   5.7| 54|     F|
|  7| 129.2|   5.3| 42|     M|
+---+------+------+---+------+
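dropna also takes how ("any"/"all") and subset arguments, so rows can be dropped only when particular columns are null; a minimal sketch:

# drop rows where age or gender is null, ignoring nulls elsewhere
df_miss1.dropna(how="any", subset=["age", "gender"]).show()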
# Fill the specified columns with fixed values
df_miss.show(), df_miss.fillna({"weight":150, "age":18, "gender":"sex", "income":6868}).show()
+---+------+------+----+------+------+
| id|weight|height| age|gender|income|
+---+------+------+----+------+------+
|  1| 143.5|   5.6|  28|     M|100000|
|  2| 167.2|   5.4|  45|     M|  null|
|  3|  null|   5.2|null|  null|  null|
|  4| 144.5|   5.9|  33|     M|  null|
|  5| 133.2|   5.7|  54|     F|  null|
|  6| 124.1|   5.2|null|     F|  null|
|  7| 129.2|   5.3|  42|     M| 76000|
+---+------+------+----+------+------+

+---+------+------+---+------+------+
| id|weight|height|age|gender|income|
+---+------+------+---+------+------+
|  1| 143.5|   5.6| 28|     M|100000|
|  2| 167.2|   5.4| 45|     M|  6868|
|  3| 150.0|   5.2| 18|   sex|  6868|
|  4| 144.5|   5.9| 33|     M|  6868|
|  5| 133.2|   5.7| 54|     F|  6868|
|  6| 124.1|   5.2| 18|     F|  6868|
|  7| 129.2|   5.3| 42|     M| 76000|
+---+------+------+---+------+------+
# Fill with the column means.
# to_dict("records") returns a list like
# [{column -> value}, ... , {column -> value}]

bb = df_miss1.agg(*[fn.mean(c).alias(c) for c in df_miss1.columns if c!="gender"]).toPandas()
bb.to_dict("records")
[{'id': 4.0,
  'weight': 140.28333333333333,
  'height': 5.471428571428572,
  'age': 40.4}]
miss_fill = df_miss1.agg(*[fn.mean(c).alias(c) for c in df_miss1.columns if c!="gender"]).toPandas().to_dict("records")[0]
miss_fill["gender"] = "M"
miss_fill
{'id': 4.0,
 'weight': 140.28333333333333,
 'height': 5.471428571428572,
 'age': 40.4,
 'gender': 'M'}
# Fill missing values using the dictionary of per-column values
df_miss1.fillna(miss_fill).show()
+---+------------------+------+---+------+
| id|            weight|height|age|gender|
+---+------------------+------+---+------+
|  1|             143.5|   5.6| 28|     M|
|  2|             167.2|   5.4| 45|     M|
|  3|140.28333333333333|   5.2| 40|     M|
|  4|             144.5|   5.9| 33|     M|
|  5|             133.2|   5.7| 54|     F|
|  6|             124.1|   5.2| 40|     F|
|  7|             129.2|   5.3| 42|     M|
+---+------------------+------+---+------+
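For numeric columns, the same mean (or median) imputation is also available through pyspark.ml.feature.Imputer, without building the dictionary by hand; a sketch (in older Spark versions Imputer requires floating-point input columns, so age is cast first; the *_filled names are illustrative):

from pyspark.ml.feature import Imputer

# cast age to double so Imputer accepts it
df_num = df_miss1.withColumn("age", df_miss1["age"].cast("double"))
imputer = Imputer(strategy="mean",
                  inputCols=["weight", "age"],
                  outputCols=["weight_filled", "age_filled"])
imputer.fit(df_num).transform(df_num).show()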

Handling outliers

  • An outlier is a value outside the normal range: a missing value, or a value far larger or smaller than the rest. Three common ways to clip them are listed below
  • Quantile-based clipping
  • Median-absolute-deviation (MAD) clipping
  • Normal-distribution (e.g. 3-sigma) clipping

All three approaches share the same core idea: derive a normal range from the original data, and treat anything outside that range as an outlier.

Background: quartiles (used for outlier detection)
(Figure: quartiles — ./imgs/四分位点.gif)

Informally, the smallest value in a dataset is called the lower bound and the largest the upper bound.
Before computing the quartiles, first find the median M.
Lower quartile Q1: the value at the 25th percentile after sorting all values in ascending order.
Upper quartile Q3: the value at the 75th percentile after sorting all values in ascending order.
Interquartile range: IQR = Q3 - Q1
Lower fence (Min): Q1 - 1.5 * IQR
Upper fence (Max): Q3 + 1.5 * IQR

df_outliers = spark.createDataFrame([
    (1, 143.5, 5.3, 28),
    (2, 154.2, 5.5, 45),
    (3, 342.3, 5.1, 99),
    (4, 144.5, 5.5, 33),
    (5, 133.2, 5.4, 54),
    (6, 124.1, 5.1, 21),
    (7, 129.2, 5.3, 42),
    ]
    , ['id', 'weight', 'height', 'age']
)
df_outliers.show()
+---+------+------+---+
| id|weight|height|age|
+---+------+------+---+
|  1| 143.5|   5.3| 28|
|  2| 154.2|   5.5| 45|
|  3| 342.3|   5.1| 99|
|  4| 144.5|   5.5| 33|
|  5| 133.2|   5.4| 54|
|  6| 124.1|   5.1| 21|
|  7| 129.2|   5.3| 42|
+---+------+------+---+
# Compute the quartile-based bounds for the numeric columns
cols = ["weight", "height", "age"]
bounds = {}

for col in cols:
    # 25th and 75th percentiles (the last argument is the relative error; 0 means exact)
    quantiles = df_outliers.approxQuantile(col, [0.25, 0.75], 0)

    # interquartile range
    IQR = quantiles[1] - quantiles[0]
    bounds[col] = [quantiles[0] - 1.5*IQR, quantiles[1] + 1.5*IQR]
bounds
{'weight': [91.69999999999999, 191.7],
 'height': [4.499999999999999, 6.1000000000000005],
 'age': [-11.0, 93.0]}
outliers = df_outliers.select(*['id'] + [((df_outliers[c] < bounds[c][0]) | (df_outliers[c] > bounds[c][1])).alias(c + '_o') for c in cols])
outliers.show()
+---+--------+--------+-----+
| id|weight_o|height_o|age_o|
+---+--------+--------+-----+
|  7|   false|   false|false|
|  6|   false|   false|false|
|  5|   false|   false|false|
|  1|   false|   false|false|
|  3|    true|   false| true|
|  2|   false|   false|false|
|  4|   false|   false|false|
+---+--------+--------+-----+
df_outliers.select("id", "weight").show()
+---+------+
| id|weight|
+---+------+
|  1| 143.5|
|  2| 154.2|
|  3| 342.3|
|  4| 144.5|
|  5| 133.2|
|  6| 124.1|
|  7| 129.2|
+---+------+
df_outliers = df_outliers.join(outliers, on="id")
df_outliers.show()
+---+------+------+---+--------+--------+-----+
| id|weight|height|age|weight_o|height_o|age_o|
+---+------+------+---+--------+--------+-----+
|  7| 129.2|   5.3| 42|   false|   false|false|
|  6| 124.1|   5.1| 21|   false|   false|false|
|  5| 133.2|   5.4| 54|   false|   false|false|
|  1| 143.5|   5.3| 28|   false|   false|false|
|  3| 342.3|   5.1| 99|    true|   false| true|
|  2| 154.2|   5.5| 45|   false|   false|false|
|  4| 144.5|   5.5| 33|   false|   false|false|
+---+------+------+---+--------+--------+-----+
df_outliers.filter("weight_o").select("id", "weight").show()
+---+------+
| id|weight|
+---+------+
|  3| 342.3|
+---+------+
df_outliers.filter("age_o").select("id", "age").show()
+---+---+
| id|age|
+---+---+
|  3| 99|
+---+---+
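With the flag columns joined back in, the flagged rows can be filtered out to obtain a cleaned DataFrame; a minimal sketch (df_clean is an illustrative name):

# keep only rows not flagged as an outlier in any column
df_clean = df_outliers.filter(~fn.col("weight_o") & ~fn.col("height_o") & ~fn.col("age_o"))
df_clean.select("id", "weight", "height", "age").show()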

