代码实现:
object test {
  /**
   * Demo: group rows by `trace_id` and collect each feature column
   * ("source", "target", "predict") into one list per trace, e.g.
   * trace2 -> [src2, src4], [tgt2, tgt4], [1.0, 1.0].
   */
  def main(args: Array[String]): Unit = {
    val session = SparkSession
      .builder()
      .appName(this.getClass.getSimpleName)
      .master("local")
      .getOrCreate()

    // Sample rows: (trace_id, source, target, predict)
    val df = session.createDataFrame(Seq(
      ("trace1", "src1", "tgt1", "1.0"),
      ("trace2", "src2", "tgt2", "1.0"),
      ("trace3", "src3", "tgt3", "1.0"),
      ("trace2", "src4", "tgt4", "1.0"),
      ("trace3", "src5", "tgt5", "1.0")
    )).toDF("trace_id", "source", "target", "predict")

    // Wrap each scalar in a one-element Seq so collect_list yields
    // Seq[Seq[String]], which `flatten` then concatenates back to Seq[String].
    val toSeq = udf((b: String) => Seq(b))
    val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten)

    var df1: DataFrame = df
    val features = Seq("source", "target", "predict")
    val featuresToList = features.map(x => flatten(collect_list(x)).alias(x))
    for (colName <- features) {
      df1 = df1.withColumn(colName, toSeq(col(colName)))
    }
    // BUG FIX: the original used agg(featuresToList.head, featuresToList.last),
    // which silently dropped the middle aggregation column ("target") and so
    // could not produce the three-column output shown below. `agg` takes
    // (Column, Column*), hence head + tail expanded as varargs.
    df1 = df1.groupBy("trace_id").agg(featuresToList.head, featuresToList.tail: _*)
    df1.show()

    // Release the local Spark resources when done.
    session.stop()
  }
}
输出格式如下:
+--------+------------+------------+----------+
|trace_id| source| target| predict|
+--------+------------+------------+----------+
| trace2|[src2, src4]|[tgt2, tgt4]|[1.0, 1.0]|
| trace3|[src3, src5]|[tgt3, tgt5]|[1.0, 1.0]|
| trace1| [src1]| [tgt1]| [1.0]|
+--------+------------+------------+----------+
改进方法:不用将每列数据更改为 Seq()
改进后的示例代码:
val aggColumns = Seq("source", "target", mi.param.PREDICT_COLUMN).map(x => collect_list(x).alias(x))
val aggDF = mi.df.groupBy("trace_id").agg(aggColumns.head, aggColumns.tail: _*)