今天看书看到了mapPartitions
,体会了一下分区操作。
package com.cnnc.sparkLearning
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
object test {
def main(args: Array[String]): Unit = {
case class Person(name:String,sex:String)
val conf = new SparkConf().setMaster("local").setAppName("test")
val sc = new SparkContext(conf)
val rdd: RDD[Int] = sc.parallelize(List(1,2,3,4,4,5),2)
def partitionCrr(s:scala.Iterator[Int]) ={
val temp = Array(0,0)
while (s.hasNext){
val i: Int = s.next()
temp(0)+=i
temp(1)+=1
}
Iterator(Tuple2(temp(0),temp(1)))//要求返回对象的迭代器
}
val tuple: (Int, Int) = rdd.mapPartitions(partitionCrr).reduce((a,b)=>(a._1+b._1,a._2+b._2))
println(tuple._1/tuple._2.toDouble)
}
}
以往的求和我们会将数字转成一个二元组,第一位放数字,第二位放1(便于累加求数字个数总和),但是这样操作会对每一个数字创建一个二元组。
然而mapPartitions
就是为了避免创建过多的对象,上述代码中,在每个分区中只创建了一个Array(0,0)
,相比之前创建6个二元组要节省一些内存。