Spark DataFrame: groupBy and collect the aggregated values of the other columns

The goal is to group the rows by trace_id and collect the values of each remaining column into one list per group. Implementation:

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, collect_list, udf}

object test {
  def main(args: Array[String]): Unit = {
    val session = SparkSession
      .builder()
      .appName(this.getClass.getSimpleName).master("local")
      .getOrCreate()

    val df = session.createDataFrame(Seq(
      ("trace1", "src1", "tgt1", "1.0"),
      ("trace2", "src2", "tgt2", "1.0"),
      ("trace3", "src3", "tgt3", "1.0"),
      ("trace2", "src4", "tgt4", "1.0"),
      ("trace3", "src5", "tgt5", "1.0")
    )).toDF("trace_id", "source", "target", "predict")

    // Wrap a single value in a one-element Seq.
    val toSeq = udf((b: String) => Seq(b))

    // collect_list over the Seq-wrapped column yields Seq[Seq[String]];
    // this UDF flattens it back to Seq[String].
    val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten)

    var df1: DataFrame = df
    val features = Seq("source", "target", "predict")
    // One aggregation expression per feature column: collect_list, then flatten.
    val featuresToList = features.map(x => flatten(collect_list(x)).alias(x))
    // Replace each feature column with its one-element Seq version.
    for (colName <- features) {
      df1 = df1.withColumn(colName, toSeq(col(colName)))
    }

    // agg takes (Column, Column*), so split the expression list into head and tail.
    df1 = df1.groupBy("trace_id").agg(featuresToList.head, featuresToList.tail: _*)

    df1.show()
  }
}

The output looks like this:

    +--------+------------+------------+----------+
    |trace_id|      source|      target|   predict|
    +--------+------------+------------+----------+
    |  trace2|[src2, src4]|[tgt2, tgt4]|[1.0, 1.0]|
    |  trace3|[src3, src5]|[tgt3, tgt5]|[1.0, 1.0]|
    |  trace1|      [src1]|      [tgt1]|     [1.0]|
    +--------+------------+------------+----------+
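Note that the aggregated feature columns come back as array&lt;string&gt;. A quick way to confirm (a minimal sketch, reusing the df1 from above; nullable flags may vary):

    df1.printSchema()
    // root
    //  |-- trace_id: string (nullable = true)
    //  |-- source: array (nullable = true)
    //  |    |-- element: string (containsNull = true)
    //  |-- target: array (nullable = true)
    //  |    |-- element: string (containsNull = true)
    //  |-- predict: array (nullable = true)
    //  |    |-- element: string (containsNull = true)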


Improved approach: there is no need to convert each column's values to Seq() first, since collect_list on a plain column already returns an array:

    val aggColumns = Seq("source", "target", mi.param.PREDICT_COLUMN).map(x => collect_list(x).alias(x))
    val aggDF = mi.df.groupBy("trace_id").agg(aggColumns.head, aggColumns.tail: _*)
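
Applied to the example df from above (a minimal sketch; mi.df and mi.param.PREDICT_COLUMN come from the author's own project, so they are replaced here with the local df and the literal column name "predict"):

    // collect_list on a plain string column already yields array<string>,
    // so no toSeq/flatten UDFs are required.
    val aggColumns = Seq("source", "target", "predict").map(x => collect_list(x).alias(x))
    val aggDF = df.groupBy("trace_id").agg(aggColumns.head, aggColumns.tail: _*)
    aggDF.show()  // same table as shown above, one row per trace_id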