需求:如下图 test.csv,dataframe 中每行都需要添加前边跟后边最近的的 SSSSSSS* 记录对应的值。
解决方案:
无法通过 lead, lag 等方法实现,因为开窗函数选定的数据框无法加上条件
思路 - 通过两次 Join 分别找到前后最近的 SSSSSS*记录,代码如下
var df = spark.read.option("header", "true").csv("C:\\Users\\XXX\\Desktop\\test.csv") // 读取文件 df = df.withColumn("seq_id", df.col("seq_id").cast(IntegerType)) // 将 seq_id 转为 Int 类型,防止后续按字符串排序 val schemas= Seq("group_key1", "seq_id1", "main_id1", "type1", "value1") val df1 = df.toDF(schemas: _*) // 复制一个 dataframe ,用于后续join val window = Window.partitionBy("group_key1", "seq_id").orderBy("seq_id1") // 注意哪里使用 seq_id 哪里使用 seq_id1 // 前边SSSSS*记录 df = df.join(df1, df.col("group_key") === df1.col("group_key1") // 通过group_key 关联 && df1.col("seq_id1") < df.col("seq_id") // seq_id 前边的 && df1.col("type1").notEqual("0"), // 通过type1来选出 SSSSS* 记录 "left") .select(df.col("*"), last(df1.col("value1")).over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)).as("pre_value") ).distinct() // 本身有值的使用旧的 df = df.withColumn("pre_value", when(df.col("value").notEqual("0"), df.col("value")) .otherwise(df.col("pre_value"))) // 后边SSSSS*记录 df = df.join(df1, df.col("group_key") === df1.col("group_key1") && df1.col("seq_id1") > df.col("seq_id") seq_id 后边的 && df1.col("type1").notEqual("0"), "left") .select(df.col("*"), first(df1.col("value1")).over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)).as("next_value") ).distinct() df = df.withColumn("next_value", when(df.col("value").notEqual("0"), df.col("value")) .otherwise(df.col("next_value"))) df.orderBy("seq_id") .show(1000, false)
结果为: