头文件algorithm中有众多排序相关的函数,其中包括nth_element函数。
nth_element()函数找到将[first, last)区间排序后的第n个元素,并将该元素置于第n个位置。
函数原型
template<class RandomAccessIterator>
void nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last);
template<class RandomAccessIterator, class Compare>
void nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last, Compare comp);
源码
// nth_element() and its auxiliary functions.
template <class _RandomAccessIter, class _Tp>
void __nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
_RandomAccessIter __last, _Tp*)
{
while (__last - __first > 3)
{
_RandomAccessIter __cut =
__unguarded_partition(__first, __last,
_Tp(__median(*__first,
*(__first + (__last - __first)/2),
*(__last - 1))));
if (__cut <= __nth)
__first = __cut;
else
__last = __cut;
}
__insertion_sort(__first, __last);
}
template <class _RandomAccessIter>
inline void nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
_RandomAccessIter __last)
{
__STL_REQUIRES(_RandomAccessIter, _Mutable_RandomAccessIterator);
__STL_REQUIRES(typename iterator_traits<_RandomAccessIter>::value_type,
_LessThanComparable);
__nth_element(__first, __nth, __last, __VALUE_TYPE(__first));
}
template <class _RandomAccessIter, class _Tp>
_RandomAccessIter __unguarded_partition(_RandomAccessIter __first,
_RandomAccessIter __last,
_Tp __pivot)
{
while (true)
{
while (*__first < __pivot)
++__first;
--__last;
while (__pivot < *__last)
--__last;
if (!(__first < __last))
return __first;
iter_swap(__first, __last);
++__first;
}
}
该算法是原地操作,空间复杂度为O(1)。时间复杂度为O(n)。
原理
_unguarded_partition 就是快速排序的 partition, 将数组分成两部分,左边的元素都小于或者等于 pivot, 右边的元素都大于或者等于 pivot.
从上述代码可以看出, nth_element 采用的 pivot 是 首元素,尾元素,中间元素,三个数的median.
通过_unguarded_partition 将数组分成两部分,
如果 nth 这个迭代器在左半边,则继续在左半边搜索;
若 nth 在右半边,则在右半边搜索;
直到数组的长度 <= 3,时, 采用插入排序。这时 nth 迭代器所指向的数就归位了,而且它的左边元素都小于或者等于它, 右边元素都大于或者等于它。
_unguarded_partition在长度为m的数组上花费的时间为o(n);每一次搜索之后,如果还未达到目标位置,则下一次搜索将是本次搜索长度的一半。
把该函数应用在长度为n的无序数组上,在最坏的情况下,所需的时间为:
n
<
n
+
n
2
+
n
4
+
.
.
.
+
1
<
2
n
n<n+\frac{n}{2}+\frac{n}{4}+...+1<2n
n<n+2n+4n+...+1<2n
所以时间复杂度在O(n)与O(2n)之间,即为线性时间复杂度。
注意:得到的结果数组中第n个位置的左边、右边的元素不一定有序