spark RadixSort基数排序源码实现

Spark版本2.4.0

 

Spark的RadixSort基数排序实现的排序容器基于一个LongArray实现。

在LongArray中,一个元素的长度为8字节,当排序的时候,将是每8个字节确定一个元素。

public static int sort(
    LongArray array, long numRecords, int startByteIndex, int endByteIndex,
    boolean desc, boolean signed) {
  assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
  assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
  assert endByteIndex > startByteIndex;
  assert numRecords * 2 <= array.size();
  long inIndex = 0;
  long outIndex = numRecords;
  if (numRecords > 0) {
    long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
    for (int i = startByteIndex; i <= endByteIndex; i++) {
      if (counts[i] != null) {
        sortAtByte(
          array, numRecords, counts[i], i, inIndex, outIndex,
          desc, signed && i == endByteIndex);
        long tmp = inIndex;
        inIndex = outIndex;
        outIndex = tmp;
      }
    }
  }
  return Ints.checkedCast(inIndex);
}

当调用sort()方法进行排序的时候,需要一个LongArray参数作为被排序的数组。

NumRecords代表数组中的元素个数,由于LongArray中一个元素是8字节,startByteIndex和endByteIndex代表参与基数排序的起始字节和结束字节,用来划分参与到排序的范围。两者在0到7之间来确定。

由于在排序中,中间过程需要在元素组中进行存放,,所以LongArray的大小必须是numRecords被排序数量的倍。

 

在关键参数确定完之后,需要通过getCounts()构建8字节各字节大小在整个数组中的数量直方图。通过二维数组counts来存放。

private static long[][] getCounts(
    LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
  long[][] counts = new long[8][];
  // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
  // If all the byte values at a particular index are the same we don't need to count it.
  long bitwiseMax = 0;
  long bitwiseMin = -1L;
  long maxOffset = array.getBaseOffset() + numRecords * 8L;
  Object baseObject = array.getBaseObject();
  for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
    long value = Platform.getLong(baseObject, offset);
    bitwiseMax |= value;
    bitwiseMin &= value;
  }
  long bitsChanged = bitwiseMin ^ bitwiseMax;
  // Compute counts for each byte index.
  for (int i = startByteIndex; i <= endByteIndex; i++) {
    if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
      counts[i] = new long[256];
      // TODO(ekl) consider computing all the counts in one pass.
      for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
        counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
      }
    }
  }
  return counts;
}

这里会构造一个8行256列的二维数组用来作为数组中各个元素每个位置上字节大小出现数量的直方图,比如数组中各个元素第8个字节为0的情况共有三个元素,这样在数组counts[7][0]则为3,一次遍历LongArray中的所有元素,统计各个字节位置上大小统计的出现总数,形成一个直方图。上文提到的startByteIndex和endByteIndex参数用来确定统计的字节范围。

 

在完成各字节出现数量的直方图统计后,将会从高位开始一次进行基数排序。具体单字节的排序在sortAtByte()方法中。

private static void sortAtByte(
    LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
    boolean desc, boolean signed) {
  assert counts.length == 256;
  long[] offsets = transformCountsToOffsets(
    counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
  Object baseObject = array.getBaseObject();
  long baseOffset = array.getBaseOffset() + inIndex * 8L;
  long maxOffset = baseOffset + numRecords * 8L;
  for (long offset = baseOffset; offset < maxOffset; offset += 8) {
    long value = Platform.getLong(baseObject, offset);
    int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
    Platform.putLong(baseObject, offsets[bucket], value);
    offsets[bucket] += 8;
  }
}

具体的排序方式如下,当正在排序第1个字节的大小的时候,先获得被排序的元素第一个字节的大小,根据之前直方图中该字节大小前面个数的数量来确定当前轮次该元素的插入位置,比如该元素第一个字节为2,则直方图中该字节0和1共出现7次,则这轮排序这个元素将被插入到第8个位置当中,同时下一个该字节为2的元素将会被插入到第9个位置上,防止冲突。

 

以此直到从起始字节排序到最后一个字节,一次基数排序也宣告结束。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值