写了一个upper_bound的实现。其中递归使用二分法求解最上界,虽然写的完全不像STL的风格,但是练手还是可以的。
#include<iostream> #include<iterator> #include<cstring> #include<cassert> using namespace std; int UpperBound(int* a, int start, int end , const int& value){ int mid = 0; int index = 0; int up_index = 0; if(a[end]<value) // 特判比最后一个数大时,end+1 return end+1; //这里是经典的二分 while(start<= end){ mid = start + (end - start)/2; if(a[mid] == value){ index = mid + 1; // 去递归的找一下上界 up_index = UpperBound(a,mid+1,end,value); // 如果找到的上界比现在的还小,那么就用现在的 return up_index > index? up_index : index; } else if(a[mid]<value){ start = mid+1; } else{ end = mid-1; // 记录上界 index = mid; } } return index; }
如果原数组中没有存在那个元素,就根本没有调用那个递归程序,递归只有在出现多个此元素时才会调用。另外中间递归调用段地方还可以改写为:
if(a[mid] == value){ index = mid + 1; /* up_index = UpperBound(a,mid+1,end,value); return up_index > index? up_index : index; */ // 因为只有在递归调用中start>end时,才会返回一个index = 0 return (mid+1>end) ? index : UpperBound(a,mid+1,end,value); }
这样写完后写一下测试代码,顺便wrap一层upper_bound:
int upper_bound(int * a,int n, const int& value){ // 使用宏可以在编译时去除 #ifdef TEST // pre-condition assert(NULL != a && n>0); // check the array a is sorted by elements for(int i = 1; i < n; i++){ assert(a[i-1]<=a[i]); } int * tmp_a = new int[n]; memcpy(tmp_a,a,sizeof(int)*n); #endif int index = UpperBound(a,0,n-1,value); #ifdef TEST // post-condition //check whether the array is changed or not for(int i = 0; i < n; i++) assert(a[i] == tmp_a[i]); if(index == 0){ // min assert(a[index]>value); } else if(index == n){ // max assert(a[index-1]<=value); } else{ assert(a[index]>value && a[index-1]<=value); } delete [] tmp_a; #endif return index; }
如果希望别人不调用UpperBound而只调用upper_bound,可以使用static 关键字, 将UpperBound限制在只在本文件中使用。别人调用就只能调用upper_bound()函数。不过STL的教学源码比我这精简的多,根本无须使用递归。真正的STL源码变量名称会使用下划线__作为起始。
template <class ForwardIterator, class T> ForwardIterator upper_bound ( ForwardIterator first, ForwardIterator last, const T& value ) { ForwardIterator it; iterator_traits<ForwardIterator>::distance_type count, step; count = distance(first,last); while (count>0) { it = first; step=count/2; advance (it,step); if (!(value<*it)) // or: if (!comp(value,*it)), for the comp version { first=++it; count-=step+1; } else count=step; } return first; }
这里的advance()函数定义类似:
template<class Iterator, class Distance> void advance(Iterator& i, Distance n); 类似于 i = x.begin()+n; i is an iterator of x
最后把主函数贴上做结:
int main(int argc, char* argv[]) { // 这里你可以使用random的方法进行测试。 int x[] = {1,2,3,4,5,5,5,5,9}; vector<int> v(x,x+9); sort(v.begin(),v.end()); cout << (int)(upper_bound(v.begin(),v.end(),9)-v.begin()); cout << upper_bound(x,9,9); return 0; }