Spark countByKey source code explained

countByKey first uses map to pull the key out of each (key, value) pair, then calls countByValue on those keys to do the counting; in other words, it is simply a reuse of countByValue, which is also explained in detail below.

PySpark source for countByKey:

    def countByKey(self):
        """
        Count the number of elements for each key, and return the result to the
        master as a dictionary.

        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.countByKey().items())
        [('a', 2), ('b', 1)]
        """
        return self.map(lambda x: x[0]).countByValue()
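
As a quick sanity check of that reuse, here is a minimal sketch (assuming a local PySpark installation; the context name is arbitrary) showing that countByKey and "map to key, then countByValue" produce the same result:

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "countByKey-demo")  # throwaway local context for the demo
    rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])

    # countByKey is exactly map-to-key followed by countByValue
    print(sorted(rdd.countByKey().items()))                        # [('a', 2), ('b', 1)]
    print(sorted(rdd.map(lambda x: x[0]).countByValue().items()))  # [('a', 2), ('b', 1)]

    sc.stop()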

countByValue counts the data within each partition, producing for every partition a dictionary of value: count (how many times each value appears in that partition), and then calls reduce to merge the per-partition dictionaries into the final counts.

The countByValue source is as follows:

    def countByValue(self):
        """
        Return the count of each unique value in this RDD as a dictionary of
        (value, count) pairs.

        >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
        [(1, 2), (2, 3)]
        """
        def countPartition(iterator):
            counts = defaultdict(int)
            for obj in iterator:
                counts[obj] += 1
            yield counts

        def mergeMaps(m1, m2):
            for k, v in m2.items():
                m1[k] += v
            return m1
        return self.mapPartitions(countPartition).reduce(mergeMaps)
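
The two helpers can be exercised with plain Python iterators, without Spark at all. A minimal sketch, reusing the countPartition and mergeMaps definitions from the source above, of one partition's local count and the driver-side merge:

    from collections import defaultdict

    def countPartition(iterator):
        counts = defaultdict(int)
        for obj in iterator:
            counts[obj] += 1
        yield counts

    def mergeMaps(m1, m2):
        for k, v in m2.items():
            m1[k] += v
        return m1

    # Pretend the RDD [1, 2, 1, 2, 2] was split into two partitions
    part1 = next(countPartition(iter([1, 2, 1])))   # {1: 2, 2: 1}
    part2 = next(countPartition(iter([2, 2])))      # {2: 2}
    print(dict(mergeMaps(part1, part2)))            # {1: 2, 2: 3}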

The functions that countByValue calls are explained in detail below.

Understanding mapPartitions:

mapPartitions is a variant of map. map's input function is applied to each element of the RDD, whereas mapPartitions' input function is applied to each partition, that is, the contents of a whole partition are processed as a unit.
The preservesPartitioning parameter indicates whether the parent RDD's partitioner information should be kept.

The mapPartitions source is as follows:

    def mapPartitions(self, f, preservesPartitioning=False):
        """
        Return a new RDD by applying a function to each partition of this RDD.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
        >>> def f(iterator): yield sum(iterator)
        >>> rdd.mapPartitions(f).collect()
        [3, 7]
        """
        def func(s, iterator):
            return f(iterator)
        return self.mapPartitionsWithIndex(func, preservesPartitioning)
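
To make the map vs. mapPartitions contrast concrete, a small sketch (again assuming a local PySpark installation; names are illustrative): map's function sees one element at a time, while mapPartitions' function sees an iterator over the whole partition:

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "mapPartitions-demo")       # throwaway local context
    rdd = sc.parallelize([1, 2, 3, 4], 2)                     # two partitions: [1, 2] and [3, 4]

    # map: the function is applied to every element
    print(rdd.map(lambda x: x * 10).collect())                # [10, 20, 30, 40]

    # mapPartitions: the function is applied once per partition
    print(rdd.mapPartitions(lambda it: [sum(it)]).collect())  # [3, 7]

    sc.stop()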

mapPartitionsWithIndex source:

mapPartitionsWithIndex does the same thing as mapPartitions, except that the supplied function takes two parameters, the first of which is the index of the partition.

    def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
        """
        Return a new RDD by applying a function to each partition of this RDD,
        while tracking the index of the original partition.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
        >>> def f(splitIndex, iterator): yield splitIndex
        >>> rdd.mapPartitionsWithIndex(f).sum()
        6
        """
        return PipelinedRDD(self, f, preservesPartitioning)
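
A small sketch of the extra index parameter (same local-context assumption as above), tagging every element with the partition it lives in:

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "withIndex-demo")  # throwaway local context
    rdd = sc.parallelize([1, 2, 3, 4], 2)

    def tag(splitIndex, iterator):
        # first argument: partition index; second: iterator over that partition
        for x in iterator:
            yield (splitIndex, x)

    print(rdd.mapPartitionsWithIndex(tag).collect())  # [(0, 1), (0, 2), (1, 3), (1, 4)]

    sc.stop()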

I am still working through this part -----

class PipelinedRDD(RDD):

    """
    Pipelined maps:
    This class is the RDD type returned by every transformation; it subclasses RDD.
    It overrides the _jrdd property so that the jrdd it returns is a PythonRDD,
    whose parent rdd is the _jrdd of the originally created rdd.
    In other words, when pyspark user code runs, JVM-side execution starts from PythonRDD.
    >>> rdd = sc.parallelize([1, 2, 3, 4])
    >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]
    >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]

    Pipelined reduces:
    >>> from operator import add
    >>> rdd.map(lambda x: 2 * x).reduce(add)
    20
    >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
    20
    """

    def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False):
        if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
            # This transformation is the first in its stage:
            # If the previous rdd is not a PipelinedRDD (or cannot be pipelined), pass its original _jrdd on directly
            self.func = func
            self.preservesPartitioning = preservesPartitioning
            self._prev_jrdd = prev._jrdd
            self._prev_jrdd_deserializer = prev._jrdd_deserializer
        else:
            prev_func = prev.func
            # pipeline_func nests the previous rdd's logic inside this rdd's logic:
            # prev_func is the function given in the previous transformation,
            # func is the function given in this transformation.
            def pipeline_func(split, iterator):
                return func(split, prev_func(split, iterator))
            self.func = pipeline_func
            self.preservesPartitioning = \
                prev.preservesPartitioning and preservesPartitioning
            self._prev_jrdd = prev._prev_jrdd  # maintain the pipeline
            self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = prev.ctx
        self.prev = prev
        self._jrdd_val = None
        self._id = None
        self._jrdd_deserializer = self.ctx.serializer
        self._bypass_serializer = False
        self.partitioner = prev.partitioner if self.preservesPartitioning else None
        self.is_barrier = prev._is_barrier() or isFromBarrier

    def getNumPartitions(self):
        # return the number of partitions of the parent JVM rdd
        return self._prev_jrdd.partitions().size()

    @property
    def _jrdd(self):
        # construct the JVM-side PythonRDD on first access and cache it
        if self._jrdd_val:
            return self._jrdd_val
        if self._bypass_serializer:
            self._jrdd_deserializer = NoOpSerializer()

        if self.ctx.profiler_collector:
            profiler = self.ctx.profiler_collector.new_profiler(self.ctx)
        else:
            profiler = None
        # serialize the user's Python function so it can be shipped to the JVM/executors
        wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
                                      self._jrdd_deserializer, profiler)
        # construct a new _jrdd of type PythonRDD; its parent rdd is the _jrdd of the original data source
        # when an action is later called on this rdd, the _jrdd that gets passed in is the one returned here
        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
                                             self.preservesPartitioning, self.is_barrier)
        self._jrdd_val = python_rdd.asJavaRDD()

        if profiler:
            self._id = self._jrdd_val.id()
            self.ctx.profiler_collector.add_profiler(self._id, profiler)
        return self._jrdd_val

    def id(self):
        if self._id is None:
            self._id = self._jrdd.id()
        return self._id

    def _is_pipelinable(self):
        return not (self.is_cached or self.is_checkpointed)

    def _is_barrier(self):
        return self.is_barrier
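
The key idea in __init__ is pipeline_func: when several narrow transformations are chained, their functions are nested into a single Python function, so only one PythonRDD has to be built on the JVM side for the whole pipeline. A minimal pure-Python sketch of that nesting (no Spark involved; prev_func and func stand in for two chained map calls):

    # Stand-ins for rdd.map(lambda x: x * 2).map(lambda x: x + 1)
    def prev_func(split, iterator):        # logic from the previous transformation
        return (x * 2 for x in iterator)

    def func(split, iterator):             # logic from the current transformation
        return (x + 1 for x in iterator)

    def pipeline_func(split, iterator):    # what PipelinedRDD builds
        return func(split, prev_func(split, iterator))

    print(list(pipeline_func(0, iter([1, 2, 3]))))  # [3, 5, 7]

Note also _is_pipelinable(): once an RDD is cached or checkpointed, the next transformation does not extend the pipeline but starts a new one from that RDD's _jrdd, which is why the else branch is only taken for pipelinable PipelinedRDD parents.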

The reduce source is as follows. This is an action, because it calls collect() internally:

    def reduce(self, f):
        """
        Reduces the elements of this RDD using the specified commutative and
        associative binary operator. Currently reduces partitions locally.

        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
        15
        >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
        10
        >>> sc.parallelize([]).reduce(add)
        Traceback (most recent call last):
            ...
        ValueError: Can not reduce() empty RDD
        """
        f = fail_on_stopiteration(f)

        def func(iterator):
            iterator = iter(iterator)
            try:
                initial = next(iterator)
            except StopIteration:
                return
            yield reduce(f, iterator, initial)

        vals = self.mapPartitions(func).collect()
        if vals:
            return reduce(f, vals)
        raise ValueError("Can not reduce() empty RDD")
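
The shape mirrors countByValue: each partition is reduced locally inside mapPartitions(func), empty partitions simply yield nothing (the StopIteration branch), and the surviving per-partition values are combined on the driver with Python's built-in reduce. A small sketch of the empty-partition case (same local-context assumption as above):

    from operator import add
    from pyspark import SparkContext

    sc = SparkContext("local[4]", "reduce-demo")  # throwaway local context
    # 4 partitions for 3 elements, so at least one partition is empty and contributes nothing
    print(sc.parallelize([1, 2, 3], 4).reduce(add))  # 6
    sc.stop()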