A BitTorrent-like implementation of broadcast.

Its mechanism is as follows:

The driver divides the serialized object into small chunks, stores those chunks in the driver's BlockManager, and reports them to the BlockManagerMaster.

On each executor, the executor first tries to read the object from its own BlockManager. If the object is not there, it asks the BlockManagerMaster for the locations of the object's chunks; initially the only location is the driver. The executor then reads the chunks from one of those locations, puts them into its local BlockManager, and tells the BlockManagerMaster that this BlockManager now holds a copy. When other executors later ask the BlockManagerMaster for the object's locations, they receive the expanded location list and pick a source at random to read from.

As these steps repeat, the BlockManagerMaster returns more and more locations, so the chance that any given read goes to the driver keeps shrinking, which relieves the load on the driver.
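Before diving into the source, here is a minimal usage sketch from the application side (the lookup map and the local master URL are made up for illustration): sc.broadcast returns a Broadcast[T], which is a TorrentBroadcast under the default broadcast factory, and reading .value inside a task is what triggers the chunk-fetching path described above.

import org.apache.spark.{SparkConf, SparkContext}

object BroadcastUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("broadcast-sketch").setMaster("local[2]"))
    // Serialized once on the driver; executors fetch it in chunks on first use.
    val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b", 3 -> "c"))
    val result = sc.parallelize(Seq(1, 2, 3))
      .map(k => lookup.value.getOrElse(k, "?")) // .value runs readBroadcastBlock() on the executor
      .collect()
    println(result.mkString(","))
    sc.stop()
  }
}

With the source excerpt below in mind, the first task on each executor pays the fetch cost; later tasks on the same executor read the cached copy from the local BlockManager.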
package org.apache.spark.broadcast

import java.io._
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.Random

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

/**
* A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
*
* The mechanism is as follows:
*
* The driver divides the serialized object into small chunks and
* stores those chunks in the BlockManager of the driver.
*
* On each executor, the executor first attempts to fetch the object from its BlockManager. If
* it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
* other executors if available. Once it gets the chunks, it puts the chunks in its own
* BlockManager, ready for other executors to fetch from.
*
* This prevents the driver from being the bottleneck in sending out multiple copies of the
* broadcast data (one per executor).
*
* When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
*
* @param obj object to broadcast
* @param id A unique identifier for the broadcast variable.
*/
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
/**
* Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
* which builds this value by reading blocks from the driver and/or other executors.
*
* On the driver, if the value is required, it is read lazily from the block manager.
*/
@transient private lazy val _value: T = readBroadcastBlock()
/** The compression codec to use, or None if compression is disabled */
@transient private var compressionCodec: Option[CompressionCodec] = _
/** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
@transient private var blockSize: Int = _
private def setConf(conf: SparkConf) {
compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
Some(CompressionCodec.createCodec(conf))
} else {
None
}
// Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
}
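// setConf runs here on the driver when the broadcast is constructed; executors call it again
// in readBroadcastBlock(), because the @transient fields above are not serialized to them.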
setConf(SparkEnv.get.conf)
private val broadcastId = BroadcastBlockId(id)
/**
* Total number of blocks this broadcast variable contains. Unlike the transient fields above,
* this val is serialized with the TorrentBroadcast object itself, so executors know how many
* pieces to fetch; writeBlocks() runs eagerly on the driver when the broadcast is created.
*/
private val numBlocks: Int = writeBlocks(obj)
override protected def getValue() = {
_value
}
/**
* Divide the object into multiple blocks and put those blocks in the block manager.
*
* @param value the object to divide
* @return number of blocks this broadcast variable is divided into
*/
private def writeBlocks(value: T): Int = {
import StorageLevel._
// Store a copy of the broadcast variable in the driver so that tasks run on the driver
// do not create a duplicate copy of the broadcast variable's value.
val blockManager = SparkEnv.get.blockManager
if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
val blocks =
TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
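// Each piece is stored with tellMaster = true so the BlockManagerMaster learns its location
// and other executors can discover and fetch it. The whole-object copy above used
// tellMaster = false, since only the pieces are ever fetched remotely.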
blocks.zipWithIndex.foreach { case (block, i) =>
val pieceId = BroadcastBlockId(id, "piece" + i)
val bytes = new ChunkedByteBuffer(block.duplicate())
if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
}
blocks.length
}
/** Fetch torrent blocks from the driver and/or other executors. */
private def readBlocks(): Array[ChunkedByteBuffer] = {
// Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
// to the driver, so other executors can pull these chunks from this executor as well.
val blocks = new Array[ChunkedByteBuffer](numBlocks)
val bm = SparkEnv.get.blockManager
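// The piece ids are shuffled so that concurrent executors fetch the pieces in different
// orders, spreading read load across all known sources instead of having everyone start
// with piece 0 from the driver.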
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)
logDebug(s"Reading piece $pieceId of $broadcastId")
// First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
bm.getLocalBytes(pieceId) match {
case Some(block) =>
blocks(pid) = block
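// getLocalBytes pins the block with a read lock; release it (or register the release
// with the running task) now that we hold a reference to the bytes.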
releaseLock(pieceId)
case None =>
bm.getRemoteBytes(pieceId) match {
case Some(b) =>
// We found the block from remote executors/driver's BlockManager, so put the block
// in this executor's BlockManager.
if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
blocks(pid) = b
case None =>
throw new SparkException(s"Failed to get $pieceId of $broadcastId")
}
}
}
blocks
}
/**
* Remove all persisted state associated with this Torrent broadcast on the executors.
*/
override protected def doUnpersist(blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
}
/**
* Remove all persisted state associated with this Torrent broadcast on the executors
* and driver.
*/
override protected def doDestroy(blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
/** Used by the JVM when serializing this object. */
private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
assertValid()
out.defaultWriteObject()
}
private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
setConf(SparkEnv.get.conf)
val blockManager = SparkEnv.get.blockManager
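// First check whether a fully assembled copy of the value is already cached on this
// executor (or on the driver), to avoid re-fetching and re-deserializing the pieces.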
blockManager.getLocalValues(broadcastId).map(_.data.next()) match {
case Some(x) =>
releaseLock(broadcastId)
x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
val startTimeMs = System.currentTimeMillis()
val blocks = readBlocks().flatMap(_.getChunks())
logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
val obj = TorrentBroadcast.unBlockifyObject[T](
blocks, SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
val storageLevel = StorageLevel.MEMORY_AND_DISK
if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
throw new SparkException(s"Failed to store $broadcastId in BlockManager")
}
obj
}
}
}
/**
* If running in a task, register the given block's locks for release upon task completion.
* Otherwise, if not running in a task then immediately release the lock.
*/
private def releaseLock(blockId: BlockId): Unit = {
val blockManager = SparkEnv.get.blockManager
Option(TaskContext.get()) match {
case Some(taskContext) =>
taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId))
case None =>
// This should only happen on the driver, where broadcast variables may be accessed
// outside of running tasks (e.g. when computing rdd.partitions()). In order to allow
// broadcast variables to be garbage collected we need to free the reference here
// which is slightly unsafe but is technically okay because broadcast variables aren't
// stored off-heap.
blockManager.releaseLock(blockId)
}
}
}
private object TorrentBroadcast extends Logging {
def blockifyObject[T: ClassTag](
obj: T,
blockSize: Int,
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
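// ChunkedByteBufferOutputStream splits everything written to it into blockSize-sized
// ByteBuffers, so serializing (and optionally compressing) straight into it yields the pieces.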
val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
val ser = serializer.newInstance()
val serOut = ser.serializeStream(out)
Utils.tryWithSafeFinally {
serOut.writeObject[T](obj)
} {
serOut.close()
}
cbbos.toChunkedByteBuffer.getChunks()
}
def unBlockifyObject[T: ClassTag](
blocks: Array[ByteBuffer],
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): T = {
require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
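// SequenceInputStream concatenates the per-piece streams, so decompression and
// deserialization see one contiguous stream in piece order.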
val is = new SequenceInputStream(
blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration)
val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = Utils.tryWithSafeFinally {
serIn.readObject[T]()
} {
serIn.close()
}
obj
}
/**
* Remove all persisted blocks associated with this torrent broadcast on the executors.
* If removeFromDriver is true, also remove these persisted blocks on the driver.
*/
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
logDebug(s"Unpersisting TorrentBroadcast $id")
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
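To make the block-size arithmetic concrete, here is a small self-contained sketch (not Spark code; the object and names are made up) of how the piece count falls out of spark.broadcast.blockSize: blockifyObject cuts the serialized (and possibly compressed) stream into blockSize-sized ByteBuffers, so an object whose stream is N bytes becomes ceil(N / blockSize) pieces.

object PieceCountSketch {
  // ceil(serializedBytes / blockSize): the number of "piece" blocks the driver stores.
  def numPieces(serializedBytes: Long, blockSize: Int): Int =
    ((serializedBytes + blockSize - 1) / blockSize).toInt

  def main(args: Array[String]): Unit = {
    val blockSize = 4 * 1024 * 1024                     // default: spark.broadcast.blockSize = 4m
    println(numPieces(10L * 1024 * 1024, blockSize))    // a 10 MB stream -> 3 pieces
    println(numPieces(1L * 1024, blockSize))            // a tiny object  -> 1 piece
  }
}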