实例
首先先来看一看broadcast的使用代码:
val factor = List[Int](1,2,3);
val factorBroadcast = sc.broadcast(factor)
val nums = Array(1,2,3,4,5,6,7,8,9)
val numsRdd = sc.parallelize(nums,3)
val list = new ListBuffer[List[Int]]()
val resRdd = numsRdd.mapPartitions(ite =>{
while (ite.hasNext){
list+=ite.next()::(factorBroadcast.value)
}
list.iterator
})
resRdd.foreach(res => println(res))
/**结果:
List(1, 1, 2, 3)
List(2, 1, 2, 3)
List(3, 1, 2, 3)
List(4, 1, 2, 3)
List(5, 1, 2, 3)
List(6, 1, 2, 3)
List(7, 1, 2, 3)
List(8, 1, 2, 3)
List(9, 1, 2, 3)
*/
首先生成了一个集合变量,把这个变量通过sparkContext的broadcast函数进行广播,最后在rdd的每一个partition迭代时,使用这个广播变量。
源码分析
接下来看看广播变量的生成与数据的读取实现部分:
广播变量的生成
(1)SparkContext.broadcast( )
/**
* 向集群广播一个只读变量,返回在分布式函数中读取它的对象。
* @return 在每个Executor上缓存一个只读变量
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
assertNotStopped()
//不能直接广播RDDs;代替,调用collect()并传播结果。
require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
"Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
//通过broadcastManager中的newBroadcast函数来进行广播.
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id +