场景:有一张部门表,和员工表,如何获得员工的一级部门(一级部门的上级部门ID为0)
部门ID | 部门名称 | 上级部门ID |
001 | 市场部 | 0 |
002 | 国内市场部 | 001 |
003 | 海外市场部 | 001 |
004 | 美国市场部 | 003 |
员工ID | 员工名称 | 部门ID |
A01 | 张三 | 003 |
A02 | 李四 | 004 |
方法:
- 方法一:暴力关联展开
- 需要先确认最高有多少级部门,有多少级部门就自关联多少次,保证最底层的部门可以关联到最终的一级部门ID
SELECT * FROM department a LEFT JOIN department b ON a.parentid=b.departmentid LEFT JOIN department c ON b.parentid=c.departmentid;
- 采用case when获得每个部门的一级部门
SELECT a.departmentid ,CASE WHEN a.parentid=0 OR a.parentid=NULL THEN a.department WHEN b.parentid=0 OR b.parentid=NULL THEN b.department WHEN c.parentid=0 OR c.parentid=NULL THEN c.department ELSE NULL END AS first_department FROM department a LEFT JOIN department b ON a.parentid=b.departmentid LEFT JOIN department c ON b.parentid=c.departmentid;
- 使用人员表关联这张新的一级部门表即可得到每个员工的一级部门
- 需要先确认最高有多少级部门,有多少级部门就自关联多少次,保证最底层的部门可以关联到最终的一级部门ID
- 方法二:SparkSQL递归实现
使用pyspark写递归函数,不断left join,直至获得每个部门的一级部门,保存在新表中,再使用人员表关联这张新的一级部门表即可得到每个员工的一级部门
# -*- coding: utf-8 -*-
# 代码中包含中文,需要转码为utf-8
import sys
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql.functions import isnull
from pyspark.sql.functions import when
try:
# for python 2
# 表中数据包含中文,需要转码为utf-8
reload(sys)
sys.setdefaultencoding('utf8')
except:
# python 3 not needed
pass
if __name__ == '__main__':
# 配置SparkSession
spark = SparkSession.builder.appName("spark sql")\
.config("spark.debug.maxToStringFields", "200")\
.config("spark.sql.crossJoin.enabled","true")\
.getOrCreate()
# 读取表中数据
# 添加新的一列,first_department
df_a = spark.sql("SELECT departmentid as depart_id, departmentname, if(parentid=0,departmentname,null) as first_department, if(parentid is null,0,parentid) as parentid FROM department")
df_b = spark.sql("SELECT departmentid as depart_id, departmentname, if(parentid=0,departmentname,null) as first_department, if(parentid is null,0,parentid) as parentid FROM department")
df_a.join(df_b,df_a.parentid==df_b.depart_id,'leftouter').show()
#编写递归函数
def recursive_left_join(df_current, df_to_join):
# 每次使用上级部门的parentid,代替现部门parentid
df_joined = df_current.join(df_to_join, df_current.parentid == df_to_join.depart_id, "leftouter").select(df_current.depart_id,df_current.departmentname,when(isnull(df_to_join.departmentname),df_current.first_department).otherwise(df_to_join.departmentname).alias("first_department"),when(isnull(df_to_join.parentid),df_current.parentid).otherwise(df_to_join.parentid).alias("parentid"))
print(df_joined.where("parentid=0").count())
# 直至所有的parentid=0,表示当前的first_department都是一级部门,完成递归
if df_joined.where("parentid=0").count() == df_joined.count():
return df_joined
return recursive_left_join(df_joined, df_to_join)
#调用递归函数
df_result = recursive_left_join(df_a, df_b)
df_select = df_result.select(df_result.depart_id,df_result.departmentname,df_result.first_department)
#建表
spark.sql("DROP TABLE IF EXISTS spark_sql_department_table")
#插入
df_select.write.mode("overwrite").saveAsTable("spark_sql_department_table")
departmentid | departmentname | first_department |
001 | 市场部 | 市场部 |
002 | 国内市场部 | 市场部 |
003 | 海外市场部 | 市场部 |
004 | 美国市场部 | 市场部 |
参考文档: