Learning Spark笔记13-Broadcast Variables(传播变量)

传播变量(Broadcast Variables)


Spark第二种类型的共享变量:传播变量,它允许程序有效地向一个或多个Spark操作中的所有工作节点发送一个大的只读值。什么时候可以派上用场呢?例如,如果你的应用程序需要发送一个大的,只读的查找表给所有节点,或者在机器学习算法中的一个大的特征向量。


回想一下,Spark会自动将关闭中引用的所有变量发送到工作节点。虽然这很方便,但也可能效率不高。第一,默认的任务启动机制是针对小任务做的优化。第二,实际上你可能在多个并行操作中使用同一个变量,但是Spark会为每个操作单独的发送它。例如,我们写一个Spark程序使用数组通过前缀匹配来查找国家。


Example 6-6. Country lookup in Python
# Look up the locations of the call signs on the
# RDD contactCounts. We load a list of call sign
# prefixes to country code to support this lookup.
signPrefixes = loadCallSignTable()
def processSignCount(sign_count, signPrefixes):
country = lookupCountry(sign_count[0], signPrefixes)
count = sign_count[1]
return (country, count)
countryContactCounts = (contactCounts.map(processSignCount).reduceByKey((lambda x, y: x+ y)))


程序运行,我们可能会有一个很大的表(IP地址会代替呼号),signPrefixes很容易就能达到MB,将数字从主节点发送给每个任务的成本就很高。此外,如果之后我们使用相同的signPrefixes对象(我们可能会运行相同的代码在file2.txt上),它会再一次发送到每个节点上。


我们可以将signPrefixes定义成传播变量来解决这个问题。一个传播变量就是spark.broadcast.Broadcast[T]类型的对象,值的类型是T。我们可以通过传播对象的value获得它的值。这个值只会发送给每个节点一次,使用类似于BitTorrent的机制。


使用传播变量,我们之前得例子像这样:


Example 6-7. Country lookup with Broadcast values in Python
# Look up the locations of the call signs on the
# RDD contactCounts. We load a list of call sign
# prefixes to country code to support this lookup.
signPrefixes = sc.broadcast(loadCallSignTable())
def processSignCount(sign_count, signPrefixes):
 country = lookupCountry(sign_count[0], signPrefixes.value)
 count = sign_count[1]
 return (country, count)
countryContactCounts = (contactCounts
 .map(processSignCount)
 .reduceByKey((lambda x, y: x+ y)))
countryContactCounts.saveAsTextFile(outputDir + "/countries.txt")




Example 6-8. Country lookup with Broadcast values in Scala
// Look up the countries for each call sign for the
// contactCounts RDD. We load an array of call sign
// prefixes to country code to support this lookup.
val signPrefixes = sc.broadcast(loadCallSignTable())
val countryContactCounts = contactCounts.map{case (sign, count) =>
 val country = lookupInArray(sign, signPrefixes.value)
 (country, count)
}.reduceByKey((x, y) => x + y)
countryContactCounts.saveAsTextFile(outputDir + "/countries.txt")


Example 6-9. Country lookup with Broadcast values in Java
// Read in the call sign table
// Look up the countries for each call sign in the
// contactCounts RDD
final Broadcast<String[]> signPrefixes = sc.broadcast(loadCallSignTable());
JavaPairRDD<String, Integer> countryContactCounts = contactCounts.mapToPair(
 new PairFunction<Tuple2<String, Integer>, String, Integer> (){
 public Tuple2<String, Integer> call(Tuple2<String, Integer> callSignCount) {
 String sign = callSignCount._1();
 String country = lookupCountry(sign, callSignInfo.value());
 return new Tuple2(country, callSignCount._2());
 }}).reduceByKey(new SumInts());
countryContactCounts.saveAsTextFile(outputDir + "/countries.txt");


使用传播变量非常简单:


1.通过SparkContext.broadcast来创建Broadcast[T]对象,任何类型都可以只要它是可序列化的。
2.使用value属性来访问值(java中使用value()方法)
3.该变量只会发送给每个节点一次,被当做只读对待(更新不会传播到其他节点)


满足只读需求的最简单的方式是声明一个传播原始值或一个不可变对象的引用。在这种情况下,你只能在驱动代码中改变传播变量的值。然而,有时它可以更简单或者更有效的声明一个广播的可变对象。如果你这样做的话,维护只读条件取决于你。像之前我们调用的前缀表Array[String],我们必须确定在工作节点的上的代码没有像这样的语句:
val theArray = broadcastArray.value; 
theArray(0) = newValue
当在一个工作节点中运行时,该行将在运行代码的工作节点的本地数组的副本中将newValue分配给第一个数组元素。它不会改变broadcastArray.value的内容在任何其他的工作节点上。




优化广播


当我们传播一个很大的值得时候,选择一个小而紧凑的数据序列化的格式是很重要的,发送到网络的时间如果太长那就造成瓶颈。特别是在java序列化的时候,Spark的scala和java api默认的序列化库对于除了原始数组类型以外都很低效。你可以使用spark.serializer的属性来使用不同的序列化类库,或者为你的数据类型实现你自己的序列化(例如,使用java.io.Externalizable接口,或者使用reduce方法为python的类库实现自定义的序列化)


使用每个分区的数据可以避免重新设置每个数据项的设置工作。类似的操作如打开数据连接或者创建一个随机数的生成器的设置步骤避免在每个元素上操作。Spark具有每个分区版本的map和foreach,通过让RDD的每个分区只运行一次代码来帮助降低这些操作的成本。


让我们回到之前呼号的例子,有一个无线电呼号的在线数据库可以查询到联系人的列表。通过使用分区的操作,我们可以分享数据库的连接池来避免建立很多的连接,然后重用JSON解析器。下面的列子,我们使用mapPartitions()函数,它给我们输入RDD的每个分区元素的迭代器,然后返回我们结果的迭代器。


Example 6-10. Shared connection pool in Python
def processCallSigns(signs):
 """Lookup call signs using a connection pool"""
 # Create a connection pool
 http = urllib3.PoolManager()
 # the URL associated with each call sign record
 urls = map(lambda x: "http://73s.com/qsos/%s.json" % x, signs)
 # create the requests (non-blocking)
 requests = map(lambda x: (x, http.request('GET', x)), urls)
 # fetch the results
 result = map(lambda x: (x[0], json.loads(x[1].data)), requests)
 # remove any empty results and return
 return filter(lambda x: x[1] is not None, result)
def fetchCallSigns(input):
 """Fetch call signs"""
 return input.mapPartitions(lambda callSigns : processCallSigns(callSigns))
contactsContactList = fetchCallSigns(validSigns)




Example 6-11. Shared connection pool and JSON parser in Scala
val contactsContactLists = validSigns.distinct().mapPartitions{
 signs =>
 val mapper = createMapper()
 val client = new HttpClient()
 client.start()
 // create http request
 signs.map {sign =>
 createExchangeForSign(sign)
// fetch responses
 }.map{ case (sign, exchange) =>
 (sign, readExchangeCallLog(mapper, exchange))
 }.filter(x => x._2 != null) // Remove empty CallLogs
}


Example 6-12. Shared connection pool and JSON parser in Java
// Use mapPartitions to reuse setup work.
JavaPairRDD<String, CallLog[]> contactsContactLists =
 validCallSigns.mapPartitionsToPair(
 new PairFlatMapFunction<Iterator<String>, String, CallLog[]>() {
 public Iterable<Tuple2<String, CallLog[]>> call(Iterator<String> input) {
 // List for our results.
 ArrayList<Tuple2<String, CallLog[]>> callsignLogs = new ArrayList<>();
 ArrayList<Tuple2<String, ContentExchange>> requests = new ArrayList<>();
 ObjectMapper mapper = createMapper();
 HttpClient client = new HttpClient();
 try {
 client.start();
 while (input.hasNext()) {
 requests.add(createRequestForSign(input.next(), client));
 }
 for (Tuple2<String, ContentExchange> signExchange : requests) {
 callsignLogs.add(fetchResultFromRequest(mapper, signExchange));
 }
 } catch (Exception e) {
 }
 return callsignLogs;
 }});
System.out.println(StringUtils.join(contactsContactLists.collect(), ","));


此外避免设置工作,我们可以使用mapPartitions()避免过多的对象创建。有时我们需要通过一个对象来聚合不同类型的结果。


Example 6-13. Average without mapPartitions() in Python
def combineCtrs(c1, c2):
 return (c1[0] + c2[0], c1[1] + c2[1])
def basicAvg(nums):
 """Compute the average"""
 nums.map(lambda num: (num, 1)).reduce(combineCtrs)
Example 6-14. Average with mapPartitions() in Python
def partitionCtr(nums):
 """Compute sumCounter for partition"""
 sumCount = [0, 0]
 for num in nums:
 sumCount[0] += num
 sumCount[1] += 1
 return [sumCount]
def fastAvg(nums):
 """Compute the avg"""
 sumCount = nums.mapPartitions(partitionCtr).reduce(combineCtrs)
 return sumCount[0] / float(sumCount[1])
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

艺菲

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值