Spark Source Code Analysis (8): ShuffleWriter

ShuffleWriter

There are three kinds of ShuffleWriter: BypassMergeSortShuffleWriter, SortShuffleWriter, and UnsafeShuffleWriter.
Let's first walk through the source code to see how the ShuffleWriter is chosen at runtime.
Start with ShuffleMapTask's runTask():

// Get the ShuffleManager
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
// First call rdd.iterator(), passing in the partition this task should process;
// the core logic lives in the rdd's iterator().
// Each record it returns is routed by the ShuffleWriter, after partitioning
// (e.g. by HashPartitioner), into the bucket of its target reduce partition.
// The writer used to default to the hash-based shuffle writer; the hash-based
// shuffle was removed in Spark 2.0 and the sort-based writers are used instead.
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
// Finally, return a MapStatus that records where the ShuffleMapTask's output
// is stored -- essentially the relevant BlockManager information.
// BlockManager is Spark's memory, data, and disk management component.
writer.stop(success = true).get

The code above obtains a ShuffleWriter and uses it to write the RDD's computation results to local disk managed by the BlockManager.
The key line for obtaining the ShuffleWriter is:

writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)

Next, let's see how dep.shuffleHandle is obtained:

val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
    shuffleId, _rdd.partitions.length, this)
override def registerShuffle[K, V, C](
      shuffleId: Int,
      numMaps: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
 if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
   /*
    * If there are no more partitions than spark.shuffle.sort.bypassMergeThreshold
    * and no map-side aggregation is needed, write numPartitions files directly
    * and concatenate them at the end. This avoids an extra round of serialization
    * and deserialization when merging spilled files.
    */
   new BypassMergeSortShuffleHandle[K, V](
     shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
 } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
   // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
   new SerializedShuffleHandle[K, V](
     shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
 } else {
   // Otherwise, buffer map outputs in a deserialized form:
   new BaseShuffleHandle(shuffleId, numMaps, dependency)
 }
}
// Decide whether the bypass-merge-sort path can be used.
// Conditions:
// (1) no map-side aggregation is required
// (2) the number of partitions is no more than spark.shuffle.sort.bypassMergeThreshold (default 200)
def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
  // We cannot bypass sorting if we need to do map-side aggregation.
  if (dep.mapSideCombine) {
    false
  } else {
    val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
    dep.partitioner.numPartitions <= bypassMergeThreshold
  }
}
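
To make these two conditions concrete, below is a minimal, hypothetical local job (not part of the Spark source; the app name and values are only illustrative). With the default threshold of 200, a groupByKey into 100 partitions satisfies both conditions, while reduceByKey does not, because it enables map-side combine.

import org.apache.spark.{SparkConf, SparkContext}

object BypassDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("bypass-merge-sort-demo")
      .setMaster("local[2]")
      // default is 200; lowering or raising it changes which shuffles take the bypass path
      .set("spark.shuffle.sort.bypassMergeThreshold", "200")
    val sc = new SparkContext(conf)

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)), 4)

    // groupByKey does no map-side combine, and 100 <= 200 partitions,
    // so this shuffle satisfies both bypass conditions
    pairs.groupByKey(100).count()

    // reduceByKey sets mapSideCombine = true, so even with few partitions
    // this shuffle cannot take the bypass path
    pairs.reduceByKey(_ + _, 100).count()

    sc.stop()
  }
}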
// Conditions for using UnsafeShuffleWriter (the serialized shuffle):
  /*
   * (1) the serializer supports relocation of serialized objects
   * (2) no map-side aggregation is required
   * (3) the number of partitions does not exceed MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE
   */
def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
  val shufId = dependency.shuffleId
  val numPartitions = dependency.partitioner.numPartitions
  if (!dependency.serializer.supportsRelocationOfSerializedObjects) {
    log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
      s"${dependency.serializer.getClass.getName}, does not support object relocation")
    false
  } else if (dependency.mapSideCombine) {
    log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " +
      s"map-side aggregation")
    false
  } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
    log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
      s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
    false
  } else {
    log.debug(s"Can use serialized shuffle for shuffle $shufId")
    true
  }
}
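
For reference, MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE is 2^24 = 16777216, because the serialized shuffle packs the partition id into 24 bits of its record pointer. The sketch below is a minimal, hypothetical configuration (again not Spark source) for a shuffle that can satisfy all three conditions: Kryo provides relocation support, partitionBy needs no map-side aggregation, and the partition count stays far below the limit.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

object SerializedShuffleDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("serialized-shuffle-demo")
      .setMaster("local[2]")
      // KryoSerializer supports relocation of serialized objects; the default
      // JavaSerializer does not. (For primitive/String key and value types,
      // Spark 2.x may pick Kryo automatically even without this setting.)
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)

    // partitionBy does no map-side aggregation, and 300 reduce partitions is
    // above the bypass threshold (200) but far below the 2^24 limit, so this
    // shuffle is a candidate for UnsafeShuffleWriter
    sc.parallelize(Seq(("a", 1), ("b", 2)), 4)
      .partitionBy(new HashPartitioner(300))
      .count()

    sc.stop()
  }
}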

If neither of the two conditions above holds, the basic SortShuffleWriter is used.
The correspondence between handles and writers (mirrored by the sketch below):
BypassMergeSortShuffleHandle corresponds to BypassMergeSortShuffleWriter
SerializedShuffleHandle corresponds to UnsafeShuffleWriter
BaseShuffleHandle corresponds to SortShuffleWriter
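
The selection order matters: registerShuffle checks the bypass condition first, then the serialized-shuffle condition, and only then falls back to BaseShuffleHandle; getWriter then dispatches on the handle type. The standalone toy sketch below (with hypothetical stand-in types, not Spark source; paste into a Scala REPL to try it) mirrors both steps.

// Toy stand-ins for the three ShuffleHandle subclasses; not Spark source.
sealed trait ToyHandle
case object BypassHandle extends ToyHandle     // stands in for BypassMergeSortShuffleHandle
case object SerializedHandle extends ToyHandle // stands in for SerializedShuffleHandle
case object BaseHandle extends ToyHandle       // stands in for BaseShuffleHandle

// Mirrors SortShuffleManager.registerShuffle: bypass first, then serialized, then base.
def selectHandle(mapSideCombine: Boolean,
                 numPartitions: Int,
                 serializerSupportsRelocation: Boolean,
                 bypassMergeThreshold: Int = 200,
                 maxSerializedPartitions: Int = 1 << 24): ToyHandle = {
  if (!mapSideCombine && numPartitions <= bypassMergeThreshold) BypassHandle
  else if (serializerSupportsRelocation && !mapSideCombine &&
           numPartitions <= maxSerializedPartitions) SerializedHandle
  else BaseHandle
}

// Mirrors SortShuffleManager.getWriter: the handle type decides the writer.
def writerFor(handle: ToyHandle): String = handle match {
  case SerializedHandle => "UnsafeShuffleWriter"
  case BypassHandle     => "BypassMergeSortShuffleWriter"
  case BaseHandle       => "SortShuffleWriter"
}

// For example:
// selectHandle(mapSideCombine = false, numPartitions = 100, serializerSupportsRelocation = true)   -> BypassHandle
// selectHandle(mapSideCombine = false, numPartitions = 1000, serializerSupportsRelocation = true)  -> SerializedHandle
// selectHandle(mapSideCombine = true, numPartitions = 100, serializerSupportsRelocation = true)    -> BaseHandle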
The next posts will cover each of these three ShuffleWriters in turn.
