Spark - PartitionPruningRDD 详解

一.引言

查看 RangePartition 的源码时发现内部用到了 PartitionPruningRDD,翻译为分区修剪 RDD,下面简单介绍一下 PartitionPruningRDD 的使用。

二.PartitionPruningRDD 简介

1.PartitionPruningRDD 释义

接触 Spark 这么久第一次见到这个全新的 RDD 形式,立马去 API 查看了下:

A.释义

-> 简单翻译一下:

用于修剪RDD分区的RDD,这样我们就可以避免在上启动任务所有分区。示例用例:如果我们知道RDD是按范围划分的,执行DAG的键上有一个过滤器,我们可以避免启动任务在没有覆盖键的范围的分区上。

-> 白话解释一下:

PartitionPruningRDD 接收两个参数,一个为原始 RDD,另一个为 PartitionFilterFunction,前者的每一个 partition 提供一些参数依据,例如 partition.id 对应的分区 id,PartitionFilterFunction 则负责根据 partition 的参数进行过滤,保留过滤结果为 true 的结果,和正常 Array 的 filter 类似。

B.用途

为什么要用 PartitionPruningRDD:

在已知 RDD 按范围划分 且 运行任务无需启动所有分区的情况下用到 PartitionPruningRDD,所以主要点有两个:

-> 已知 RDD 按范围划分

这个涉及到 RDD 下的多个 partitions,一般只有自定义 Partitioner 函数得到的 partition 或者是顺序保存的文件得到的 partition 才是已知范围划分的,例如自定义 Partitioner 将 10000 个数据分为5个 partition,partition-0 对应 0-1999 的数据,partition-1 对应 2000-3999 的数据,以此类推,这样就可以成为我们已知 RDD 按范围划分。

-> 任务运行无需启动所有分区

这个就基于具体场景而定了,例如 RangePartitioner 生成分区边界时,会对样本数量超过均值的 partition 进行 resample 重采样,这时候只需要拿出来不满足条件的 partition 重新操作即可,满足的 partition 则略过,这里其实也要求用户对 RDD 的范围划分或者基础信息是已知的,否则无法定义好 PartitionFilterFunction 过滤函数。

2.代码浅析

PartitionPruningRDD 继承了 new PruneDependency(prev, partitionFilterFunc)),其中 prev 为 RDD[T],partitionFilterFunc 为 partition 的过滤函数:

释义表示 PartitionPruningRDD 与其父对象之间的依赖关系。在这种情况下,子 RDD 包含父 RDD 的分区子集。这里 partitions 的生成方式可以很清晰的看到 PartitionPruningRDD 的由来方式:

  @transient
  val partitions: Array[Partition] = rdd.partitions
    .filter(s => partitionFilterFunc(s.index)).zipWithIndex
    .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }

rdd.partitions 获取 Array[Partition] 即当前 rdd 下所有 parititon,s.index 可以理解为 partitionId,可以通过 rdd.mapPartitionsWithIndex 获取 partition 的 id 也可以通过环境变量 TaskContext.getPartitionId() 获取对应 partitionId,剩下的则是通过 filter + PartitionFilterFunction 进行过滤,将结果为 true 的 partition 保留并生成 new PartitionPruningRDDPartition。

所以可以近似理解为这样的过滤方法:

    def filterFunc(id: Int): Boolean = {
      id % 2 == 0
    }  
    
    val purningRdd = rdd.partitions.filter(partition => {
      val partitionId = TaskContext.getPartitionId()
      filterFunc(partitionId)
    })

不过这里只能得到对应的 Partition 类,而不是对应的 partrition: Iterator[T],所以想要筛选一部分 partition 还是要采用 PartitionPruningRDD。

三.PartitionPruningRDD 使用与验证 

1.PartitionPruningRDD 使用

A.构造 Common RDD

    val conf = new SparkConf().setAppName("PartitionTest").setMaster("local[5]")
    val spark = SparkSession
      .builder
      .config(conf)
      .getOrCreate()
    val sc = spark.sparkContext
    sc.setLogLevel("error")

    val oriPartitions = 5
    val randomArray = (0 until 10000).toArray
    val info = java.util.Arrays.asList(randomArray: _*).toArray().map(_.asInstanceOf[Int]).zipWithIndex
    val rdd = sc.parallelize(info).repartition(oriPartitions)

这里使用 Local[5] + repartition(5) 得到 5个 partition,每个 partition 包含 2000 个元素:

    rdd.foreachPartition(partition => {
      val taskId = TaskContext.getPartitionId()
      println(s"TaskId: $taskId PartitionSize: ${partition.size}")
    })

 rdd 下共有 5 个 partition,分别对应 ID 0-4,每个 partition 下有 2000 个元素。

B.构造 PartitionPruningRDD

    val imbalancedPartitions = mutable.Set.empty[Int]
    imbalancedPartitions += 0

    val pruningRDD = new PartitionPruningRDD(rdd, imbalancedPartitions.contains)

PartitionFilterFunction 为 Set.contains ,这里过滤条件为只保留 TaskID 为 0 的 partition,即5选1,prev RDD 为前面构造的包含 10000 个元素的 RDD[(Int, Int)]。

2.PartitionPruningRDD 验证

为了验证 PartitionPruningRDD 是否只保留了 PartitionId=0 的数据,我们通过把 RDD 数据 collect 为 Set 的方法进行验证:

A.CommonRddSet

    val commonRDDSet = rdd.mapPartitionsWithIndex { (idx, iter) =>
      val info = new ArrayBuffer[Int]()
      if (idx == 0) {
        iter.foreach(x => info.append(x._1))
      }
      Iterator(info.toArray)
    }.collect().flatten.toSet

使用 mapPartitionsWithIndex 获取对应 partitionId - idx 与对应迭代器 iter - Iterator[(Int, Int)],这里将 partitionId = 0 的 partition 的 (Int,Int)._1 都保留下来并 collect 为 Set。

B.PruningRDDSet

    val imbalancedPartitions = mutable.Set.empty[Int]
    imbalancedPartitions += 0

    val pruningRDDSet = new PartitionPruningRDD(rdd, 
                imbalancedPartitions.contains)
                .collect()
                .map(_._1)
                .toSet

FilterFunc 采用 Set(0).contains,只保留 partitionId=0 的 partition,并将对应元素的 _._1 collect 得到 Set。

C.验证

    println(commonRDDSet.size, pruningRDDSet.size)
    println((commonRDDSet -- pruningRDDSet).mkString(","))
    println((pruningRDDSet -- commonRDDSet).mkString(","))
    println(pruningRDDSet.intersect(commonRDDSet).size)

分别判断两个 Set 的 size,以及双向的差集,最后判断交集。

两个 Set 均为 2000 元素,且双向差集均为空,最后交集元素大小为 2000,所以两个 Set 完全一致,即 PartitionPruningRDD 保存了 RDD 对应 Id 的 partition,非常的奈斯。

四.总结

使用 Spark 很久,不管是工作代码还是浏览博客都没有注意到 PartitionPruningRDD 这个 RDD,也是看 RangePartition 的源码时发现了有这个简易过滤 partition 的 RDD 方法,整体来说 PartitionPruningRDD 应用场景比较少,因为使用 PartitionPruningRDD 的前提是 RDD 按范围划分,这个在大数据场景下经常无法得到这么规整的数据集,其次使用 mapPartition 和 foreachPartition 时返回 null 再过滤 -> filter(_ != null) 也可以快速实现过滤某个 partition 下的所有数据,从而导致工作学习中也很少见到该 RDD。不过 RangePartition 生成 boundary 就是其应用场景之一,这里先做一个铺垫,后续介绍 RangePartition 时可以借鉴 PartitionPruningRDD 的使用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BIT_666

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值