Spark combineByKey with a Custom Key: A Source-Code Walkthrough
Coming from MapReduce, using a custom object as a key is an everyday pattern. This post digs into Spark's combineByKey to show how to implement a custom key correctly. Whether in MR or in combineByKey, the core idea is the same: records with the same key form one group.
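As a quick refresher, combineByKey takes three functions: createCombiner, mergeValue, and mergeCombiners. A minimal word-count-style sketch (the RDD name `pairs` is illustrative, standing for any RDD[(String, Int)]):

// pairs: RDD[(String, Int)]
val counts = pairs.combineByKey(
  (v: Int) => v,                  // createCombiner: first value seen for a key in a partition
  (c: Int, v: Int) => c + v,      // mergeValue: fold in another value within the partition
  (c1: Int, c2: Int) => c1 + c2   // mergeCombiners: merge per-partition results
)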
A brief analysis of AppendOnlyMap
AppendOnlyMap
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
// Keys and values share one array; an element's slot is derived from the key's hash
private var data = new Array[AnyRef](2 * capacity)

// combineByKey calls this method for every record: a key is only ever inserted (the first
// time it appears, at the slot computed from its hash); its value is updated in place on
// every later occurrence
def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
  assert(!destroyed, destructionMessage)
  val k = key.asInstanceOf[AnyRef]
  if (k.eq(null)) {
    if (!haveNullValue) {
      incrementSize()
    }
    nullValue = updateFunc(haveNullValue, nullValue)
    haveNullValue = true
    return nullValue
  }
  // The slot is computed from the key's hash, which is why a custom key must override hashCode
  var pos = rehash(k.hashCode) & mask
  var i = 1
  while (true) {
    val curKey = data(2 * pos)
    if (curKey.eq(null)) {
      // The key is not present yet: insert it.
      // updateFunc is createCombiner or mergeValue
      val newValue = updateFunc(false, null.asInstanceOf[V])
      data(2 * pos) = k
      data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
      incrementSize()
      return newValue
    } else if (k.eq(curKey) || k.equals(curKey)) {
      // The slot is occupied: distinct keys can share a hash code, so equals must be
      // overridden as well to confirm it really is the same key; then update the value
      val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
      data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
      return newValue
    } else {
      // Collision with a different key: quadratic probing, advance by i and retry
      val delta = i
      pos = (pos + delta) & mask
      i += 1
    }
  }
  null.asInstanceOf[V] // Never reached but needed to keep compiler happy
}
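The else branch above implements quadratic probing: after the i-th collision the position advances by i. A minimal sketch of the resulting probe sequence (capacity 8 and starting slot 5 are made-up values for illustration):

object ProbeDemo {
  def main(args: Array[String]): Unit = {
    val mask = 7                        // capacity 8 => mask = capacity - 1
    var pos = 5                         // pretend rehash(k.hashCode) & mask == 5
    var i = 1
    val visited = scala.collection.mutable.ArrayBuffer(pos)
    while (i <= 4) {                    // show the first few collision steps
      pos = (pos + i) & mask            // same step as the else branch in changeValue
      visited += pos
      i += 1
    }
    println(visited.mkString(" -> "))   // 5 -> 6 -> 0 -> 3 -> 7
  }
}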
/** Iterator method from Iterable */
override def iterator: Iterator[(K, V)] = {
  assert(!destroyed, destructionMessage)
  new Iterator[(K, V)] {
    var pos = -1

    /** Get the next value we should return from next(), or null if we're finished iterating */
    def nextValue(): (K, V) = {
      if (pos == -1) { // Treat position -1 as looking at the null value
        if (haveNullValue) {
          return (null.asInstanceOf[K], nullValue)
        }
        pos += 1
      }
      while (pos < capacity) {
        if (!data(2 * pos).eq(null)) {
          return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
        }
        pos += 1
      }
      null
    }

    override def hasNext: Boolean = nextValue() != null

    override def next(): (K, V) = {
      val value = nextValue()
      if (value == null) {
        throw new NoSuchElementException("End of iterator")
      }
      pos += 1
      value
    }
  }
}

override def size: Int = curSize

/** Increase table size by 1, rehashing if necessary */
private def incrementSize() {
  curSize += 1
  if (curSize > growThreshold) {
    growTable()
  }
}
// ExternalSorter's insertAll method defines the update function like this:
val update = (hadValue: Boolean, oldValue: C) => {
  if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
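Putting the pieces together, map-side aggregation of one partition boils down to the loop below. This is a simplified sketch, not Spark's exact code (the real insertAll also tracks memory usage and can spill to disk, and AppendOnlyMap itself is private[spark]):

// Simplified sketch: how per-record updates drive changeValue during aggregation.
// createCombiner and mergeValue are the first two arguments of combineByKey.
def aggregatePartition[K, V, C](
    records: Iterator[(K, V)],
    createCombiner: V => C,
    mergeValue: (C, V) => C): AppendOnlyMap[K, C] = {
  val map = new AppendOnlyMap[K, C]
  records.foreach { kv =>
    val update = (hadValue: Boolean, oldValue: C) =>
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    map.changeValue(kv._1, update)   // one hash lookup + insert/update per record
  }
  map
}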
Conclusion
A custom key must override hashCode and equals so that keys can be compared; only then are identical keys guaranteed to end up in the same group.
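Note that a plain case class already generates hashCode and equals over all of its fields, which is often not the grouping you want. A small demonstration (DefaultStudent is a hypothetical class for contrast):

case class DefaultStudent(school: String, name: String, age: Int)

val a = DefaultStudent("SchoolA", "Alice", 18)
val b = DefaultStudent("SchoolA", "Bob", 18)
println(a == b)   // false: the generated equals compares every field
// With the Student class below (hashCode/equals on school only), the two
// corresponding Student keys would be equal and fall into the same group.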
Test
package com.yzz.spark

import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}

/**
 *
 * @time 2020/7/11 21:46
 * @author yzz
 * @E-mail yzzstyle@163.com
 * @since 0.0.1
 */
object TestCombineByKey {
  def main(args: Array[String]): Unit = {
    val config = new SparkConf()
      .setAppName("WordCount")
      .setMaster("local")
      .setJars(List("C:\\work\\study\\spark_2020\\wordCount\\target\\wordCount-1.0-SNAPSHOT-jar-with-dependencies.jar"))
    val sc = new SparkContext(config)
    val p = new SchoolPartition(3)
    val rdd = sc.parallelize(List(
      (Student("SchoolA", "StudentA", 18), 1),
      (Student("SchoolA", "StudentB", 18), 1),
      (Student("SchoolB", "StudentC", 18), 1),
      (Student("SchoolC", "StudentD", 18), 1)))
    val rdd1 = rdd.partitionBy(p)
    val rdd2 = rdd1.combineByKey((x: Int) => x, (x: Int, y: Int) => x + y, (x: Int, y: Int) => x + y, p)
    rdd2.collect().foreach(println)
  }
  class SchoolPartition(val partitions: Int) extends Partitioner with Serializable {
    override def numPartitions: Int = partitions

    override def getPartition(key: Any): Int = key match {
      case Student(school, _, _) =>
        val code = school.hashCode % partitions
        code + (if (code < 0) partitions else 0)
      case _ => 0
    }

    /**
     * Must be overridden: this decides whether two RDDs count as co-partitioned,
     * which directly affects stage boundaries and therefore whether a shuffle happens.
     *
     * @param obj
     * @return
     */
    override def equals(obj: Any): Boolean = obj match {
      case p: SchoolPartition =>
        p.partitions == partitions
      case _ => false
    }
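    // Added for symmetry: whenever equals is overridden, hashCode should be too
    // (Spark's built-in HashPartitioner overrides both in the same way)
    override def hashCode: Int = partitions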
  }
  case class Student(school: String, name: String, age: Int) extends Serializable with Ordered[Student] {
    /**
     * Sort by school, then name, then age
     *
     * @param that
     * @return
     */
    override def compare(that: Student): Int = {
      val c1 = school.compareTo(that.school)
      if (c1 == 0) {
        val c2 = name.compareTo(that.name)
        if (c2 == 0) age - that.age else c2
      } else c1
    }

    /**
     * Students from the same school form one group
     *
     * @param other
     * @return
     */
    override def equals(other: Any): Boolean = other match {
      case that: Student =>
        that.school.equals(school)
      case _ => false
    }

    override def hashCode(): Int = {
      val state = Seq(school)
      state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
    }

    override def toString: String = s"[$school,$name,$age]"
  }
}
Summary
Whenever a custom key is needed: every grouping operation requires the key to override hashCode and equals, and every sorting operation requires an Ordering (or the key extending Ordered) before the custom key can be sorted correctly.
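For example, Student above already extends Ordered[Student], so sortByKey works on it out of the box; for a key type that does not, you can supply an implicit Ordering instead. A minimal sketch (Course and SortByKeyDemo are hypothetical names):

import org.apache.spark.rdd.RDD

object SortByKeyDemo {
  // Hypothetical key type that does not extend Ordered
  case class Course(school: String, title: String)

  // sortByKey resolves an implicit Ordering[K]; Ordering.by builds one from a sort key
  implicit val courseOrdering: Ordering[Course] =
    Ordering.by(c => (c.school, c.title))   // sort by school, then title

  def sortCourses(rdd: RDD[(Course, Int)]): RDD[(Course, Int)] = rdd.sortByKey()
}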