Spark版本2.4.0
Spark的RadixSort基数排序实现的排序容器基于一个LongArray实现。
在LongArray中,一个元素的长度为8字节,当排序的时候,将是每8个字节确定一个元素。
public static int sort(
LongArray array, long numRecords, int startByteIndex, int endByteIndex,
boolean desc, boolean signed) {
assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
assert endByteIndex > startByteIndex;
assert numRecords * 2 <= array.size();
long inIndex = 0;
long outIndex = numRecords;
if (numRecords > 0) {
long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
for (int i = startByteIndex; i <= endByteIndex; i++) {
if (counts[i] != null) {
sortAtByte(
array, numRecords, counts[i], i, inIndex, outIndex,
desc, signed && i == endByteIndex);
long tmp = inIndex;
inIndex = outIndex;
outIndex = tmp;
}
}
}
return Ints.checkedCast(inIndex);
}
当调用sort()方法进行排序的时候,需要一个LongArray参数作为被排序的数组。
NumRecords代表数组中的元素个数,由于LongArray中一个元素是8字节,startByteIndex和endByteIndex代表参与基数排序的起始字节和结束字节,用来划分参与到排序的范围。两者在0到7之间来确定。
由于在排序中,中间过程需要在元素组中进行存放,,所以LongArray的大小必须是numRecords被排序数量的倍。
在关键参数确定完之后,需要通过getCounts()构建8字节各字节大小在整个数组中的数量直方图。通过二维数组counts来存放。
private static long[][] getCounts(
LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
long[][] counts = new long[8][];
// Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
// If all the byte values at a particular index are the same we don't need to count it.
long bitwiseMax = 0;
long bitwiseMin = -1L;
long maxOffset = array.getBaseOffset() + numRecords * 8L;
Object baseObject = array.getBaseObject();
for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
long value = Platform.getLong(baseObject, offset);
bitwiseMax |= value;
bitwiseMin &= value;
}
long bitsChanged = bitwiseMin ^ bitwiseMax;
// Compute counts for each byte index.
for (int i = startByteIndex; i <= endByteIndex; i++) {
if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
counts[i] = new long[256];
// TODO(ekl) consider computing all the counts in one pass.
for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
}
}
}
return counts;
}
这里会构造一个8行256列的二维数组用来作为数组中各个元素每个位置上字节大小出现数量的直方图,比如数组中各个元素第8个字节为0的情况共有三个元素,这样在数组counts[7][0]则为3,一次遍历LongArray中的所有元素,统计各个字节位置上大小统计的出现总数,形成一个直方图。上文提到的startByteIndex和endByteIndex参数用来确定统计的字节范围。
在完成各字节出现数量的直方图统计后,将会从高位开始一次进行基数排序。具体单字节的排序在sortAtByte()方法中。
private static void sortAtByte(
LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
boolean desc, boolean signed) {
assert counts.length == 256;
long[] offsets = transformCountsToOffsets(
counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
Object baseObject = array.getBaseObject();
long baseOffset = array.getBaseOffset() + inIndex * 8L;
long maxOffset = baseOffset + numRecords * 8L;
for (long offset = baseOffset; offset < maxOffset; offset += 8) {
long value = Platform.getLong(baseObject, offset);
int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
Platform.putLong(baseObject, offsets[bucket], value);
offsets[bucket] += 8;
}
}
具体的排序方式如下,当正在排序第1个字节的大小的时候,先获得被排序的元素第一个字节的大小,根据之前直方图中该字节大小前面个数的数量来确定当前轮次该元素的插入位置,比如该元素第一个字节为2,则直方图中该字节0和1共出现7次,则这轮排序这个元素将被插入到第8个位置当中,同时下一个该字节为2的元素将会被插入到第9个位置上,防止冲突。
以此直到从起始字节排序到最后一个字节,一次基数排序也宣告结束。