源码
// 通过将函数应用于此迭代器生成的所有值并连接结果来创建一个新的迭代器。
//参数:f - 应用于每个元素的函数
def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] {
private var cur: Iterator[B] = empty
private def nextCur(): Unit = { cur = null ; cur = f(self.next()).toIterator }
def hasNext: Boolean = {
// Equivalent to cur.hasNext || self.hasNext && { nextCur(); hasNext }
// but slightly shorter bytecode (better JVM inlining!)
while (!cur.hasNext) {
if (!self.hasNext) return false
nextCur()
}
true
}
def next(): B = (if (hasNext) cur else empty).next()
}
flatMap其实就是将RDD里的每一个元素执行自定义函数f,这时这个元素的结果转换成iterator,最后将这些再拼接成一个新的RDD,也可以理解成原本的每个元素由横向执行函数f后再变为纵向。next一直在回调,当RDD内没有元素为止。
源码改写1:
def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] {
private var cur: Iterator[B] = empty
private def nextCur() {
cur = f(self.next()).toIterator
}
def hasNext: Boolean = {
// Equivalent to cur.hasNext || self.hasNext && { nextCur(); hasNext }
// but slightly shorter bytecode (better JVM inlining!)
while (!cur.hasNext) { //如果当前迭代器没有值
if (!self.hasNext) return false //如果自身迭代器没有值,返回false
nextCur() //如果自身迭代器有值,调用f函数 把self的一个值放入cur中
}
true
}
def next(): B = (if (hasNext) cur else empty).next()
}
// false || true && true = true
// false || true && false = false
源码改写2:
// f 函数是 传入每一条数据都需要返回一个迭代器
// 也就是说一条记录可以返回多个值
def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] {
//定义当前的迭代器是空的
private var cur: Iterator[B] = empty
//这是源码,为了方便理解,我稍微改写了下
//def hasNext: Boolean =
//cur.hasNext || self.hasNext && {
// cur = f(self.next).toIterator;
// hasNext
//}
def hasNext: Boolean ={
if(cur.hasNext){
//如果当前迭代器还有值,
//则返回true
return true
}
if(self.hasNext){
//如果cur已经没有值了
//但是本身的迭代器还有值
//则我们把本身迭代器的一个值拿出来
//通过 f函数 构造一个迭代器放到当前的迭代器
cur = f(self.next).toIterator;
//再递归一次本函数来看是否还有值
return hasNext
}
}
//这个就没什么好说的了
def next(): B = (if (hasNext) cur else empty).next()
}
flatMap遍历:
def main(args: Array[String]): Unit = {
val conf: SparkConf = new SparkConf().setAppName(this.getClass.getSimpleName).setMaster("local[*]")
val sc: SparkContext = SparkContext.getOrCreate(conf)
val list2: RDD[List[Int]] = sc.parallelize(List(List(1,2,3,4,5,6),List(2,3,4)))
//todo flatMap遍历
//1
list2.flatMap(x=>x).foreach(println)
//2
list2.flatMap(x=>{
println(x)
x
}).foreach(f=>println("aa"))
//3
println("yield------------")
list2.flatMap(x=>{
println(x)
for (a<- x)
yield a
}).foreach(println)
}