Spark rdd之mappartition妙用

13 篇文章 0 订阅

mapPartitions(func)源码

def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U],preservesPartitioning: Boolean = false): RDD[U] = withScope {
    val cleanedF = sc.clean(f)
    new MapPartitionsRDD(this,(_: TaskContext, _: Int, iter: Iterator[T]) => cleanedF(iter),preservesPartitioning)
  }

类似于 Map 算子,但是不是基于每一条数据,而是基于一个 partition 来计算的,func 将接受一个迭代器,可以从迭代器中获取每一条数据进行操作,返回一个迭代器。形成一个新的 RDD。

该算子一般用于优化 Map 算子,如下面这个例子:

sc.parallelize(Seq(1, 2, 3, 4, 5),1)
.mapPartitions(iter => {
    var res = List[Int]()
    //创建 mysql 客户端
    println("连接数据库")
    while (iter.hasNext) {
        val next = iter.next()
        println("向数据库写入数据:" + next)
        res = res :+ next
    }
    res.toIterator
  })
  .foreach(println)

输出如下,我们可以发现,通过一次连接我们就将一个 partition的数据都写入了数据库,
如果使用的是 Map 算子,那么每写入一条数据都需要一次数据库连接,很明显是不划算的

#一次加载一个分区的数据
连接数据库
向数据库写入数据:1
向数据库写入数据:2
向数据库写入数据:3
向数据库写入数据:4
向数据库写入数据:5
1
2
3
4
5

上面的写法并非最优写法,我们可以这样写:

 sc.parallelize(Seq(1, 2, 3, 4, 5),1)
  .mapPartitions(iter => {
    var res = List[Int]()
    println("连接数据库")
    iter.map(next=>{
      println("向数据库写入数据:" + next)
      next
    })
  })
  .foreach(println)

输出如下,其中的差异你可以细细体会,
不但代码更简单, 而且可以防止partition数据过大导致的 OOM 等问题:

#加载一条数据 写一条数据
连接数据库
向数据库写入数据:1
1
向数据库写入数据:2
2
向数据库写入数据:3
3
向数据库写入数据:4
4
向数据库写入数据:5
5

不过这种写法无法关闭数据库,更好的是自定义一个迭代器。迭代器一对一

case class CustomIterator22(iterator: Iterator[Int]) extends Iterator[Int] {
  println("开启数据库")

  override def hasNext: Boolean = {
    val hasNext: Boolean = iterator.hasNext
    if (!hasNext) {
      println("关闭数据库")
    }
    hasNext
  }

  override def next(): Int = {
    val next: Int = iterator.next()
    println("写入数据"+next)
    next
  }
}

main{
		//优化2:自定义迭代器
    rdd1.mapPartitions(iter=>{CustomIterator22(iter)}).foreach(println)
}
开启数据库
写入数据1
1
写入数据2
2
写入数据3
3
写入数据4
4
写入数据5
5
关闭数据库

迭代器一对多

mappartition + flatMap + iterator

def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[*]")
    val sc: SparkContext = SparkContext.getOrCreate(conf)

    val list2: RDD[List[Int]] = sc.parallelize(List(List(1,2,3,4,5,6),List(2,3,4)))

    list2.mapPartitions(f=>CustomIter(f))
      .foreach(println)
  }

  //自定义迭代器 一对多
  case class CustomIter(iter:Iterator[List[Int]]) extends Iterator[Int] {

    //自定义处理函数
    def myF(list: List[Int]): Iterator[Int] ={
      list.iterator
    }

    //创建空的迭代器
    private var cur: Iterator[Int] = Iterator.empty

    override def hasNext: Boolean = {
      cur.hasNext || iter.hasNext && {
        cur = myF(iter.next())
        hasNext
      }
    }

    override def next(): Int = {
      (if (hasNext) cur else Iterator.empty).next()
    }
  }
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值