Spark Broadcast: The TorrentBroadcast Implementation

In Spark, the default broadcast mechanism is the Torrent one, implemented by the TorrentBroadcast class. When data is broadcast through the SparkContext's broadcast() method, a unique broadcast id is generated to distinguish that broadcast variable.
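
A minimal usage sketch for context (the session setup, app name, and data here are illustrative, not part of the Spark source discussed below):

import org.apache.spark.sql.SparkSession

object BroadcastDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("broadcast-demo").master("local[*]").getOrCreate()
    val sc = spark.sparkContext
    // sc.broadcast() returns a TorrentBroadcast by default; each call gets a fresh broadcast id.
    val lookup = sc.broadcast(Map(1 -> "one", 2 -> "two"))
    // Tasks read the value lazily through .value (see readBroadcastBlock() below).
    val names = sc.parallelize(Seq(1, 2, 1)).map(x => lookup.value.getOrElse(x, "?")).collect()
    println(names.mkString(", "))
    spark.stop()
  }
}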

 

During the construction of a TorrentBroadcast, the data to be broadcast is split into blocks and serialized by the writeBlocks() method. The actual splitting logic is implemented in the blockifyObject() method of the companion object.

def blockifyObject[T: ClassTag](
    obj: T,
    blockSize: Int,
    serializer: Serializer,
    compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
  // Accumulates the serialized bytes into fixed-size ByteBuffer chunks.
  val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
  // Optionally wrap the chunked stream in a compression stream.
  val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
  val ser = serializer.newInstance()
  val serOut = ser.serializeStream(out)
  Utils.tryWithSafeFinally {
    serOut.writeObject[T](obj)
  } {
    serOut.close()
  }
  cbbos.toChunkedByteBuffer.getChunks()
}

As seen here, the object to be broadcast is serialized as a stream; the actual splitting of the serialized bytes into blocks happens inside ChunkedByteBufferOutputStream.

ChunkedByteBufferOutputStream holds an ArrayBuffer of ByteBuffers that stores the serialized broadcast data.

override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
  require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
  var written = 0
  while (written < len) {
    // Start a new chunk when the current one is full (or on the very first write).
    allocateNewChunkIfNeeded()
    // Copy as many bytes as fit into the current chunk.
    val thisBatch = math.min(chunkSize - position, len - written)
    chunks(lastChunkIndex).put(bytes, written + off, thisBatch)
    written += thisBatch
    position += thisBatch
  }
  _size += len
}

@inline
private def allocateNewChunkIfNeeded(): Unit = {
  if (position == chunkSize) {
    chunks += allocator(chunkSize)
    lastChunkIndex += 1
    position = 0
  }
}

The write() method of ChunkedByteBufferOutputStream above allocates one ByteBuffer of the chunk size (4MB by default) at a time and writes data into it, requesting the next chunk only once the current ByteBuffer is full, until all of the serialized broadcast data has been consumed.
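
To make the chunking concrete, here is a minimal, self-contained sketch of the same idea; the class and method names are illustrative, not Spark's actual internals:

import java.io.OutputStream
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

// Collects written bytes into fixed-size ByteBuffer chunks, in the spirit of
// ChunkedByteBufferOutputStream (simplified).
class MiniChunkedOutputStream(chunkSize: Int) extends OutputStream {
  private val chunks = ArrayBuffer[ByteBuffer]()
  private var position = chunkSize  // forces an allocation on the first write

  override def write(b: Int): Unit = write(Array(b.toByte), 0, 1)

  override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
    var written = 0
    while (written < len) {
      if (position == chunkSize) {  // current chunk exhausted: allocate a new one
        chunks += ByteBuffer.allocate(chunkSize)
        position = 0
      }
      val thisBatch = math.min(chunkSize - position, len - written)
      chunks.last.put(bytes, off + written, thisBatch)
      written += thisBatch
      position += thisBatch
    }
  }

  // Flip each chunk so it can be read back from the beginning.
  def toChunks: Array[ByteBuffer] =
    chunks.map { c => val d = c.duplicate(); d.flip(); d }.toArray
}

For example, writing 10 bytes through a MiniChunkedOutputStream(4) produces three chunks: two full 4-byte buffers and one holding the remaining 2 bytes.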

 

Next, look at TorrentBroadcast's writeBlocks() method.

private def writeBlocks(value: T): Int = {
  import StorageLevel._
  // Store a copy of the broadcast variable in the driver so that tasks run on the driver
  // do not create a duplicate copy of the broadcast variable's value.
  val blockManager = SparkEnv.get.blockManager
  if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
    throw new SparkException(s"Failed to store $broadcastId in BlockManager")
  }
  val blocks =
    TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
  if (checksumEnabled) {
    checksums = new Array[Int](blocks.length)
  }
  blocks.zipWithIndex.foreach { case (block, i) =>
    if (checksumEnabled) {
      checksums(i) = calcChecksum(block)
    }
    val pieceId = BroadcastBlockId(id, "piece" + i)
    val bytes = new ChunkedByteBuffer(block.duplicate())
    if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
      throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
    }
  }
  blocks.length
}

With the serialized chunks in hand, writeBlocks() iterates over the ByteBuffers, generates a piece id for each chunk, and persists it in the local BlockManager with tellMaster = true, so the master records its location. Note that a full, unserialized copy of the value is also stored via putSingle(), so that tasks running on the driver do not create a duplicate copy. Since writeBlocks() runs on the driver, the pieces initially live in the driver's BlockManager, which is where executors first fetch them from.
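
For reference, the calcChecksum() call seen above computes an Adler32 checksum over each chunk; a minimal sketch of such a helper (the handling of direct, non-array-backed buffers is simplified here):

import java.nio.ByteBuffer
import java.util.zip.Adler32

// Checksums the readable bytes of a chunk without disturbing its position.
def checksum(block: ByteBuffer): Int = {
  val adler = new Adler32()
  if (block.hasArray) {
    // Array-backed buffer: feed the backing array directly.
    adler.update(block.array(), block.arrayOffset() + block.position(), block.remaining())
  } else {
    // Direct buffer: copy the readable bytes out first.
    val bytes = new Array[Byte](block.remaining())
    block.duplicate().get(bytes)
    adler.update(bytes)
  }
  adler.getValue.toInt
}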

 

On an executor, the broadcast value is loaded lazily, through the readBroadcastBlock() method, only when it is actually needed.

@transient private lazy val _value: T = readBroadcastBlock()
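
The @transient lazy val combination is what makes this loading lazy and per-process: the field is excluded from serialization, so after the Broadcast object is shipped to an executor, the value is recomputed there on first access. A minimal illustration of the pattern (not Spark's actual class):

// load() runs at most once per JVM that deserializes this object.
class LazyHolder[T](load: () => T) extends Serializable {
  // Excluded from serialization; recomputed lazily on first access
  // in each JVM that deserializes this holder.
  @transient private lazy val _value: T = load()
  def value: T = _value
}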

readBroadcastBlock() first checks whether the value of this broadcast variable already exists in the local BlockManager; if so, it is returned directly. This is why tasks running in the same JVM as the driver (in local mode, for example) never fetch the value remotely: the putSingle() call in writeBlocks() already stored a full copy there.

If the value is not available locally, the pieces have to be fetched through the readBlocks() method.

private def readBlocks(): Array[BlockData] = {
  // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
  // to the driver, so other executors can pull these chunks from this executor as well.
  val blocks = new Array[BlockData](numBlocks)
  val bm = SparkEnv.get.blockManager

  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val pieceId = BroadcastBlockId(id, "piece" + pid)
    logDebug(s"Reading piece $pieceId of $broadcastId")
    // First try getLocalBytes because there is a chance that previous attempts to fetch the
    // broadcast blocks have already fetched some of the blocks. In that case, some blocks
    // would be available locally (on this executor).
    bm.getLocalBytes(pieceId) match {
      case Some(block) =>
        blocks(pid) = block
        releaseLock(pieceId)
      case None =>
        bm.getRemoteBytes(pieceId) match {
          case Some(b) =>
            if (checksumEnabled) {
              val sum = calcChecksum(b.chunks(0))
              if (sum != checksums(pid)) {
                throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                  s" $sum != ${checksums(pid)}")
              }
            }
            // We found the block from remote executors/driver's BlockManager, so put the block
            // in this executor's BlockManager.
            if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
              throw new SparkException(
                s"Failed to store $pieceId of $broadcastId in local BlockManager")
            }
            blocks(pid) = new ByteBufferBlockData(b, true)
          case None =>
            throw new SparkException(s"Failed to get $pieceId of $broadcastId")
        }
    }
  }
  blocks
}

Here, the piece ids of the broadcast variable, from 0 up to numBlocks, are shuffled into a random order and fetched one by one. The random order spreads fetches across all the nodes that already hold pieces, avoiding hotspots where every executor pulls the same pieces from the same source in the same order.

If a needed piece is not present locally, it is fetched from a remote node via the BlockManager's getRemoteBytes() method.

In the BlockManager, every piece of a broadcast variable is tracked along with the locations that currently hold it. When another node needs a piece, it asks the BlockManager master for all locations of that block and tries them one by one until the piece is successfully fetched.
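
A simplified sketch of this try-each-location pattern (the function and parameter names are illustrative, not Spark's actual internals):

// Try each known location in turn and return the first successful fetch.
def fetchFromLocations[B](locations: Seq[String], fetchFrom: String => Option[B]): Option[B] = {
  for (loc <- locations) {
    fetchFrom(loc) match {
      case Some(bytes) => return Some(bytes)
      case None        => // fetch failed or block gone; try the next location
    }
  }
  None
}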

 

Once a piece has been obtained from getRemoteBytes(), it is persisted locally with tellMaster = true, which registers this executor as a new location for that piece. Executors that need the piece later can then download it from here rather than everyone pulling it from the driver; this is what gives the mechanism its BitTorrent-like scalability.

 

After all pieces have been downloaded, they are ordered by piece number and deserialized in sequence, yielding the broadcast value.
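
This reassembly is the inverse of the blockifyObject() method shown earlier (in Spark it is unBlockifyObject()). A minimal sketch under the same Serializer/CompressionCodec types, with the stream concatenation simplified:

import java.io.{InputStream, SequenceInputStream}
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer

def unBlockify[T: ClassTag](
    blocks: Array[InputStream],
    serializer: Serializer,
    compressionCodec: Option[CompressionCodec]): T = {
  // Read the ordered chunks back-to-back as one stream.
  val combined = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
  // Undo the optional compression applied in blockifyObject().
  val in = compressionCodec.map(_.compressedInputStream(combined)).getOrElse(combined)
  val serIn = serializer.newInstance().deserializeStream(in)
  try serIn.readObject[T]() finally serIn.close()
}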
