STL 中的划分算法 partition 的函数原型如下
template <class ForwardIterator, class Predicate>
ForwardIterator partition(ForwardIterator first, ForwardIterator last, Predicate pred);
四点说明:
1、功能:将 [first, last) 中所有满足 pred 的元素置于不满足 pred 的元素前面。
2、返回值:设返回的迭代器为 i,则对 [first, i) 中的任意迭代器 j,pred(*j) 为真,对 [i, last) 中的任意迭代器 k,pred(*k) 为假。
3、要求:ForwardIterator 必须满足 ValueSwappable。
4、复杂度:如果 ForwardIterator 符合 BidirectionalIterator 的要求,则最多进行 (last - first) / 2 次交换;否则最多 last - first 次交换。且执行 last - first 次的 pred 操作。
源码如下:
/********************************************************************
created: 2014/04/25 15:37
filename: partition.cpp
author: Justme0 (http://blog.csdn.net/justme0)
purpose: partition
*********************************************************************/
#include <cstdio>
#include <cstdlib>
typedef int T;
template <class ForwardIterator1, class ForwardIterator2>
inline void iter_swap(ForwardIterator1 a, ForwardIterator2 b) {
T tmp = *a; // 源码中的 T 由迭代器的 traits 得来,这里简化了
*a = *b;
*b = tmp;
}
template <class Type>
inline void swap(Type &a, Type &b) {
Type tmp = a;
a = b;
b = tmp;
}
/*
** 设返回值为 mid,则[first, mid)中迭代器指向的值满足 pred;
** [mid, last)中迭代器指向的值不满足 pred
** 为双向迭代器设计的算法
** 迭代器用到了自减操作
*/
template <class BidirectionalIterator, class Predicate>
BidirectionalIterator partition(BidirectionalIterator first, BidirectionalIterator last, Predicate pred) {
while(true) {
while(true) {
if (first == last) {
return first;
} else if (pred(*first)) {
++first;
} else {
break;
}
}
--last;
while(true) {
if (first == last) {
return first;
} else if (!pred(*last)) {
--last;
} else {
break;
}
}
iter_swap(first, last);
++first;
}
}
/*
** 为前向迭代器设计的算法
** 这里为了不与为双向迭代器设计的 partition 同名,加了后缀以区别
** 源码中由 traits 得到不同类型的迭代器标记形成重载
*/
template <class ForwardIterator, class Predicate>
ForwardIterator partition2(ForwardIterator first, ForwardIterator last, Predicate pred) {
if (first == last) {
return first;
}
while (pred(*first)) {
if (++first == last) {
return first;
}
}
ForwardIterator next = first;
while (++next != last) {
if (pred(*next)) {
swap(*first, *next);
++first;
}
}
return first;
}
/*
** 设返回值为 mid,则[first, mid)中迭代器指向的值小于等于 pivot;
** [mid, last)中迭代器指向的值大于等于 pivot
** 这是 STL 内置的算法,会用于 nth_element, sort 中
** 笔者很困惑为什么不用 partition
*/
template <class RandomAccessIterator, class Type>
RandomAccessIterator unguarded_partition(RandomAccessIterator first, RandomAccessIterator last, Type pivot) {
while(true) {
while (*first < pivot) {
++first;
}
--last;
while (pivot < *last) { // 若 std::partition 的 pred 是 IsLess(pivot),这里将是小于等于
--last;
}
if (!(first < last)) { // 小于操作只适用于 random access iterator
return first;
}
iter_swap(first, last);
++first;
}
}
template <class Type>
struct IsLess {
IsLess(const Type &pivot) : m_pivot(pivot) {}
bool operator()(const Type &other) const {
return other < m_pivot;
}
private:
Type m_pivot;
};
void print(int *array, int len) {
for (int i = 0; i < len; ++i) {
printf("%d ", array[i]);
}
printf("\n");
}
int main(void) {
{
int arr[] = {2, 30, 30, 17, 33, 40, 17, 23, 22, 12, 20};
int size = sizeof arr / sizeof *arr;
int *mid = partition(arr, arr + size, IsLess<int>(22));
print(arr, size); // 2 20 12 17 17 40 33 23 22 30 30
printf("partition return %d\n", mid - arr); // partition return 5
}
{
int arr[] = {2, 30, 30, 17, 33, 40, 17, 23, 22, 12, 20};
int size = sizeof arr / sizeof *arr;
int *mid = partition2(arr, arr + size, IsLess<int>(22));
print(arr, size); // 2 17 17 12 20 40 30 23 22 30 33
printf("partition2 return %d\n", mid - arr); // partition2 return 5
}
{
int arr[] = {2, 30, 30, 17, 33, 40, 17, 23, 22, 12, 20};
int size = sizeof arr / sizeof *arr;
int *mid = unguarded_partition(arr, arr + size, 22);
print(arr, size); // 2 20 12 17 22 17 40 23 33 30 30
printf("unguarded_partition return %d\n", mid - arr); // unguarded_partition return 6
}
system("PAUSE");
return 0;
}
如果大家对程序有任何疑问可以在下面回复,因为当中省略了很多 STL 的设计技巧以突出算法。