Spark combineByKey with a Custom Key: Source Code Analysis

For anyone used to MapReduce, using a custom object as the key is a very common situation. This post analyzes Spark's combineByKey to see how to implement a custom key. Whether in MR or in combineByKey, the core idea is the same: records with the same key form one group.
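
As a quick refresher, here is a minimal, self-contained combineByKey sketch on ordinary String keys (the app name and sample data are illustrative, not from this post):

import org.apache.spark.{SparkConf, SparkContext}

// A word-count-style aggregation: records with the same key are combined into one value
val sc = new SparkContext(new SparkConf().setAppName("combineByKeyDemo").setMaster("local"))
val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)))
val counts = pairs.combineByKey(
  (v: Int) => v,                 // createCombiner: turns the first value seen for a key into a combiner
  (c: Int, v: Int) => c + v,     // mergeValue: folds another value into the per-partition combiner
  (c1: Int, c2: Int) => c1 + c2  // mergeCombiners: merges combiners from different partitions
)
counts.collect().foreach(println) // (a,2), (b,1)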

A brief look at AppendOnlyMap

AppendOnlyMap (abridged)
  // Holds keys and values in the same array for memory locality; specifically, the order of
  // elements is key0, value0, key1, value1, key2, value2, etc.
  // Keys and values live in the same array; an element's slot is located from the key's hash value
  private var data = new Array[AnyRef](2 * capacity)
  
  // For combineByKey, this method is called once for every record. Keys are only ever inserted, never modified; values are updated in place.
  // A key is inserted the first time it appears: its position is computed from the hash and the key is written there.
  def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    assert(!destroyed, destructionMessage)
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
      if (!haveNullValue) {
        incrementSize()
      }
      nullValue = updateFunc(haveNullValue, nullValue)
      haveNullValue = true
      return nullValue
    }
    // The slot is computed from the key's hash code, so a custom key must override hashCode
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      // Key not present yet: insert it
      if (curKey.eq(null)) { 
        // updateFunc is createCombiner (first time) or mergeValue (subsequent times)
        val newValue = updateFunc(false, null.asInstanceOf[V])
        data(2 * pos) = k
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        incrementSize()
        return newValue
       // A key already occupies this slot: equals decides whether it is the same key (hash codes can collide), so a custom key must also override equals
      } else if (k.eq(curKey) || k.equals(curKey)) {
        // Same key: merge the new value into the existing one
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        return newValue
      } else {
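        // A different key hashed to this slot: probe the next position (quadratic probing)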
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    null.asInstanceOf[V] // Never reached but needed to keep compiler happy
  }

  /** Iterator method from Iterable */
  override def iterator: Iterator[(K, V)] = {
    assert(!destroyed, destructionMessage)
    new Iterator[(K, V)] {
      var pos = -1

      /** Get the next value we should return from next(), or null if we're finished iterating */
      def nextValue(): (K, V) = {
        if (pos == -1) {    // Treat position -1 as looking at the null value
          if (haveNullValue) {
            return (null.asInstanceOf[K], nullValue)
          }
          pos += 1
        }
        while (pos < capacity) {
          if (!data(2 * pos).eq(null)) {
            return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
          }
          pos += 1
        }
        null
      }

      override def hasNext: Boolean = nextValue() != null

      override def next(): (K, V) = {
        val value = nextValue()
        if (value == null) {
          throw new NoSuchElementException("End of iterator")
        }
        pos += 1
        value
      }
    }
  }

  override def size: Int = curSize

  /** Increase table size by 1, rehashing if necessary */
  private def incrementSize() {
    curSize += 1
    if (curSize > growThreshold) {
      growTable()
    }
  }
  // ExternalSorter.insertAll defines the update function that is passed to changeValue
  val update = (hadValue: Boolean, oldValue: C) => {
    if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
  }
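
To see how this closure is driven record by record, here is a small self-contained sketch (not Spark source): it uses a plain mutable.HashMap instead of Spark's AppendOnlyMap, which is private[spark], and createCombiner/mergeValue are stand-ins for the aggregator's functions:

import scala.collection.mutable

// Stand-ins for the aggregator's functions (word-count style)
val createCombiner = (v: Int) => v
val mergeValue = (c: Int, v: Int) => c + v

val records = Iterator(("a", 1), ("b", 1), ("a", 1))
val map = mutable.HashMap.empty[String, Int]

var kv: (String, Int) = null
// Same shape as the update closure above: hadValue says whether the key was
// already present, oldValue is the current combiner for that key
val update = (hadValue: Boolean, oldValue: Int) =>
  if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)

while (records.hasNext) {
  kv = records.next()
  val hadValue = map.contains(kv._1)              // lookup relies on the key's hashCode/equals
  map(kv._1) = update(hadValue, map.getOrElse(kv._1, 0))
}
println(map) // e.g. Map(a -> 2, b -> 1)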

Conclusion

A custom key must override both hashCode and equals so that keys can be compared correctly; only then are records with the same key guaranteed to end up in the same group.

Test

package com.yzz.spark

import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}
/**
  *
  * @time 2020/7/11 21:46
  * @author yzz 
  * @E-mail yzzstyle@163.com
  * @since 0.0.1
  */
object TestCombineByKey {

  def main(args: Array[String]): Unit = {
    val config = new SparkConf()
      .setAppName("WordCount")
      .setMaster("local")
      .setJars(List("C:\\work\\study\\spark_2020\\wordCount\\target\\wordCount-1.0-SNAPSHOT-jar-with-dependencies.jar"))
    val sc = new SparkContext(config)

    val p = new SchoolPartition(3)
    val rdd = sc.parallelize(List((Student("学校A", "学生A", 18), 1), (Student("学校A", "学生B", 18), 1), (Student("学校B", "学生C", 18), 1), (Student("学校C", "学生D", 18), 1)))
    val rdd1 = rdd.partitionBy(p)
    val rdd2 = rdd1.combineByKey((x: Int) => x, (x: Int, y: Int) => x + y, (x: Int, y: Int) => x + y, p)
    rdd2.collect().foreach(println)


  }

  class SchoolPartition(val partitions: Int) extends Partitioner with Serializable {
    override def numPartitions: Int = partitions

    override def getPartition(key: Any): Int = key match {
      case Student(school, _, _) =>
        val code = school.hashCode % partitions
        code + (if (code < 0) partitions else 0)
      case _ => 0
    }

    /**
      * Must be overridden: partitioner equality determines whether two RDDs are
      * considered co-partitioned, which directly affects stage boundaries and
      * therefore whether a shuffle is needed.
      *
      * @param obj
      * @return
      */
    override def equals(obj: Any): Boolean = obj match {
      case p: SchoolPartition =>
        p.partitions == partitions
      case _ => false
    }

    // Overriding equals should come with a matching hashCode
    override def hashCode: Int = partitions


  }

  case class Student(school: String, name: String, age: Int) extends Serializable with Ordered[Student] {
    /**
      * Order by school, then name, then age
      *
      * @param that
      * @return
      */
    override def compare(that: Student): Int = {
      val c1 = school.compareTo(that.school)
      if (c1 == 0) {
        val c2 = name.compareTo(that.name)
        if (c2 == 0) age - that.age else c2
      } else c1
    }

    /**
      * Students from the same school belong to the same group
      *
      * @param other
      * @return
      */
    override def equals(other: Any): Boolean = other match {
      case that: Student =>
        that.school.equals(school)
      case _ => false
    }

    override def hashCode(): Int = {
      val state = Seq(school)
      state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
    }

    override def toString: String = s"[$school,$name,$age]"


  }


}
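
With equals and hashCode comparing only the school, the four input records collapse into three groups, so rdd2.collect() should return something like Array(([学校A,学生A,18],2), ([学校B,学生C,18],1), ([学校C,学生D,18],1)); the Student instance kept as the representative key of each group is simply whichever one was inserted first.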

Summary

Whenever a custom key is needed: any grouping operation requires overriding hashCode and equals, and any sorting operation requires providing an Ordering (or having the key extend Ordered) so that the custom key can be sorted correctly.
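
Since the example above only exercises grouping, here is a minimal sketch of the sorting side, assuming the Student case class and the rdd value from the test above (sortedBySchool is a name introduced here). Because Student extends Ordered[Student], Scala derives an implicit Ordering[Student] automatically, so sortByKey needs no extra arguments:

// Sort the (Student, Int) pairs by key: school, then name, then age
val sortedBySchool = rdd.sortByKey()
sortedBySchool.collect().foreach(println)

An explicit Ordering (for example Ordering.by((s: Student) => s.age)) could also be put into implicit scope to sort by a different criterion.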
