std::sort
源码分析之快速排序
众所周知,快速排序分为三个部分,主元(pivot)选择,元素划分和递归。通常,我们应重点关注快速排序的主元选择和划分环节,这两个部分也是面试考察的高频考点。
主元选择
在《算法导论》中提到,主元选择会对快速排序的时间复杂度产生不可忽略的影响。选择一个好的主元,能提升快排的效率,而某些情况下,选择的主元会使快速排序的时间复杂度抵达下界—— O ( n 2 ) O(n^2) O(n2)。
在此问题上,通常的实现方法是从待排序的数组中抽取部分元素,并取中位数,以此避免最坏情况。虽然这种做法存在抽样的开销,但一定能避免遇到最坏情况,可以认为这样做是值得的。
主元选择具体的实现如下:
- 当数组长度小于
40
时,比较头尾以及正中间的元素,从三个元素中选取中位数。 - 当数组长度大于
40
时,将数组8等分,从9个等分点(包括首尾)中选取中位数。- 从数组首部三个等分点中选择一个中位数
- 从数组尾部三个等分点中选择一个中位数
- 从数组中部三个等分点钟选择一个中位数
- 从上述三个中位数中再次挑选
- 将选择的主元放在数组的中间位置,记为
mid
代码实现:
此实现是标准库实现的复刻,仅对变量命名,代码格式进行了调整,另外省略了用于DEBUG检查和traits特性的代码。
// 挑选三个位置的中位数,并将其置于 mid 中
// first, mid, last 是指向数组的迭代器,它们均需指向有效位置
// comp是用于比较的函数
template <typename It, typename Comp>
void medianOfThree(It first, It mid, It last, Comp comp){
if(comp(*mid, *first)){ // arr[mid] < arr[first]
iter_swap(mid, first); // 交换 mid 和 first 指向的值,不改变迭代器的位置!
}
// 此时必然有 arr[mid] > arr[first], 进一步比较 mid 和 last 的关系
if(comp(*last, *mid)){ // arr[last] < arr[mid]
iter_swap(last, mid); // last 成为最大值,中位数位将在 mid 和 first 中产生
if(comp(*mid, *first))
iter_swap(mid, first); // 比较 mid 和 first 的较大者,将其置于 mid 中
}
}
// 挑选九个位置的中位数,并置于 mid 中
template <typename It, typename Comp>
void medianOfNine(It first, It mid, It last, Comp comp){
const auto size = last - first; // 排序范围的长度,因为last指向尾元素,比实际长度小一
const auto step = (size + 1) >> 3; // 单步步长,实际上就是数组长度的八分之一
const auto twoStep = step << 1; // 两步步长,等于数组长度的四分之一
medianOfThree(first, first + step, first + twoStep, comp); // 比较前三个等分点
medianOfThree(mid - step, mid, mid + step, comp); // 比较中间三个等分点
medianOfThree(last - twoStep, last - step, last, comp); // 比较后三个等分点
// 再进行最后一轮比较,将得到的中位数置于 mid 中
medianOfThree(first + step, mid, last - step, comp);
}
划分
划分是将数组进行分割的操作,通常的划分方法是将数组按主元划分为两个部分,主元左侧为小于主元的数,右侧为大于主元的数。但这样的需要考虑多种情况以及边界条件,还有与主元相同值的处理。
在标准库的实现中,数组被划分为了三个部分,第一部分为小于主元的数,第二部分是和主元相同的数,第三部分是大于主元的数,即划分之后的数组形如:S...S P...P G...G
,S
、P
、G
分别代表小于,等于和大于主元的元素。
注意到,在上文主元选择中,我们将主元放在了中间mid
的位置上,故可以从mid
开始,向左右两侧扩展,并寻找不符合划分规则的元素,进行交换。
具体过程描述如下:
-
记主元位置为
pFirst
,那么初始的P区间为 [ pFirst , pFirst + 1 ) [\text{pFirst},\text{pFirst} + 1) [pFirst,pFirst+1),记pLast = pFirst + 1
,这样与主元相同的元素就位于pFirst
和pLast
迭代器描述的左闭右开区间之中了。- 完成初始化之后,判断
pFirst
能否向左扩展,即是否有*(pFirst - 1) == *pFirst
,满足条件就向左移动pFirst
。 - 同理,向右扩展
pLast
,注意PLast
描述的是一个开区间,故只需比较*pLast == *pFirst
。
auto mid = begin + (end - begin) / 2; median(begin, mid, end - 1, comp); // 将 pivot 置于 mid auto pFirst = mid, pLast = mid + 1; // elem in range [pFirst, pLast) equals pivot while(pLast < end && !comp(*pLast, *pFirst) && !comp(*pFirst, *pLast)){ ++pLast; // 如果pLast等于pivot,则向右扩展 pLast } // 注意此处判断等于的方法 // 同理,向左扩展 pFirst, 但是应注意开闭区间之间的区别 while(pFirst > begin && !comp(*pFirst, *(pFirst - 1)) && !comp(*(pFirst - 1), *pFirst)) { --pFirst; }
- 完成初始化之后,判断
-
至此完成了P区间的初步构造,接下来就是正式的划分过程了。记左侧还未划分的区间终点为
SLast
,右侧还未划分的区间起点为GFirst
,这样整个数组就被四个指针划分成为5个部分。初始化时,令sLast = pFirst
、gFirst = pLast
。/* * |------------|---------|---------|--------|----------| * | 未划分部分 | < pivot | = pivot | >pivot | 未划分部分 | * sLast pFirst pLast gFirst */ auto sFirst = pFirst, gFirst = pLast;
接下来只需要向左右扩展sLast
和gFirst
,完成整个数组的划分。
-
先对
gFirst
执行操作,向右移动gFirst
:- 当遇到一个比
piovt
小的元素时,在原地等待; - 当遇到一个等于
pivot
的元素时,将放在pLast
的位置上,并更新pLast
- 当遇到一个比
pivot
大的元素时,继续向右移动。
for (; gFirst < end; ++gFirst) { if (func(*pFirst, *gFirst)) // gFirst > pivot, do nothing continue; else if (func(*gFirst, *pFirst)) // gFirst < pivot, stop break; else if (gFirst != pLast) // gFirst == pivot and gFirst != pLast iter_swap(gFirst, pLast++); // swap and inc pLast else ++pLast; // gFirst == pLast, no need to swap }
- 当遇到一个比
-
同理,对
sLast
进行操作,向左移动sLast
:- 根据当前元素的大小情况,执行原地等待,与
pivot
交换,或向左移动等操作 - 注意,因为区间开闭的不同,对
sLast
的操作也略有不同
for (; begin < sLast; --sLast) { if(func(*(sLast - 1), *pFirst)) // do nothing continue; else if(func(*pFirst, *(sLast - 1))) // stop break; else if( pFirst-- != (sLast - 1)) // swap and dec pFirst iter_swap(pFirst, sLast - 1); }
- 根据当前元素的大小情况,执行原地等待,与
-
当左右指针都停止后,先检查是否划分完毕,即左右指针已经走到了顶端:
if(sLast == begin && gFirst == end) { return { pFirst,pLast }; }
-
接着考虑特殊情况——只有一侧的指针走到了顶端,另一侧指针还在半途,此时有一侧已经没有空间了,这时需要移动P区间,来腾出空间。
虽说是移动区间,但区间内的元素都是相同的,所以实际上只需要将区间的端点元素向相反方向进行一次交换就可以了
- 若是左侧没有空间,即当前情况如图所示:
/* * |---------|---------|--------|----------| * | < pivot | = pivot | >pivot | 未划分部分 | * begin(sLast) pFirst pLast gFirst */
在这种情况下,只需要将
pFirst
处的元素移动到pLast
处,更新P区间,这样就为左侧腾出了一个位置。if(sLast == begin) { // 当左侧空间已满 if (pLast != gFirst) // 如果 pLast 和 gFirst 重合,则没必要先腾空位 { iter_swap(pFirst, pLast); // 将 pivot区间右移一个单位,此时左侧有一个大于pivot的元 // 素,并且正好位于 pFirst 位置上 } ++pLast; iter_swap(gFirst++, pFirst++); // 交换并更新区间 }
- 对于右侧没有空间,做法也是类似的:
else if(gFirst == end) { if(--sLast != --pFirst) { iter_swap(sLast, pFirst); } iter_swap(pFirst, --pLast); }
- 最后,也是最平凡的情况,交换左右指针的元素,扩展左右区间
else { iter_swap(--sLast, gFirst); ++gFirst; }
以上就是快速排序的划分过程。显然划分的返回值是一个区间,由两个迭代器组成的pair,它们分别标志了划分后的起点和终点,位于这个区间中的元素都处在正确的位置上。
递归
递归是快速排序中最简单的一个部分,此处不再叙述。
完整代码
把上文出现的代码进行组合,并加入递归的过程,就可以得到完成的快速排序实现:
template <typename It, typename Comp>
void medianOfThree(It first, It mid, It last, Comp comp) {
if (comp(*mid, *first)) {
iter_swap(mid, first);
}
if (comp(*last, *mid)) {
iter_swap(last, mid);
if (comp(*mid, *first))
iter_swap(mid, first);
}
}
template <typename It, typename Comp>
void medianOfNine(It first, It mid, It last, Comp comp) {
const auto size = last - first;
const auto step = (size + 1) >> 3;
const auto twoStep = step << 1;
medianOfThree(first, first + step, first + twoStep, comp);
medianOfThree(mid - step, mid, mid + step, comp);
medianOfThree(last - twoStep, last - step, last, comp);
medianOfThree(first + step, mid, last - step, comp);
}
template <typename It, typename Comp>
void median(It first, It mid, It last, Comp comp) {
auto size = last - first + 1;
if(size < 3)
return;
if (size < 40)
medianOfThree(first, mid, last, comp);
else
medianOfNine(first, mid, last, comp);
}
template <typename It, typename Pred>
pair<It,It> partition(It begin, It end, Pred func) {
auto mid = begin + (end - begin) / 2;
median(begin, mid, end - 1, func);
auto pFirst = mid, pLast = mid + 1;
while(pLast < end && !func(*pLast, *pFirst)&& !func(*pFirst, *pLast)) {
++pLast;
}
while(pFirst > begin && !func(*pFirst, *(pFirst - 1)) && !func(*(pFirst - 1), *pFirst)) {
--pFirst;
}
auto sLast = pFirst, gFirst = pLast;
while (true) {
for (; gFirst < end; ++gFirst) {
if (func(*pFirst, *gFirst)) {
continue;
}
else if (func(*gFirst, *pFirst)) {
break;
}
else if (gFirst != pLast) {
iter_swap(gFirst, pLast++);
}
else {
++pLast;
}
}
for (; begin < sLast; --sLast) {
if(func(*(sLast - 1), *pFirst)) {
continue;
}
else if(func(*pFirst, *(sLast - 1))) {
break;
}else if( pFirst-- != (sLast - 1)) {
iter_swap(pFirst, sLast - 1);
}
}
if(sLast == begin && gFirst == end) {
return { pFirst,pLast };
}
if(sLast == begin) {
if (pLast != gFirst)
{
iter_swap(pFirst, pLast);
}
++pLast;
iter_swap(gFirst++, pFirst++);
}else if(gFirst == end) {
if(--sLast != --pFirst) {
iter_swap(sLast, pFirst);
}
iter_swap(pFirst, --pLast);
}
else {
iter_swap(--sLast, gFirst);
++gFirst;
}
}
}
template <typename It, typename Pred>
void quickSort(It begin, It end, Pred func = less<>{}) {
if (begin + 1 == end)
return;
auto mid = partition(begin, end, func);
if (begin < mid.first) {
quickSort(begin, mid.first, func);
}
if (mid.second < end) {
quickSort(mid.second, end, func);
}
}