解法一:分块哈希计数(易理解,效率高)
思路:
关键字的范围是1 ~ 100000,建立hash数组cnt[100001],cnt[i]即栈中关键字i的数量,即从1到100000遍历cnt[],易找到从小到大第K个关键字,但会消耗大量的时间,因此考虑分块的思想。
分块查找当将长N的表划分为sqrt(N)块(考虑边界应向上取整),每块sqrt(N)个关键字时,获得最高效率(顺序查找);
建立块总关键字数统计数组block[317],block[i]即第i块(关键字 i * 317至 (i + 1) * 317 - 1)中关键字总数;
检索部分:
int sum = 0, K = ( s.size() + 1 ) / 2, i, j;
for( i = 0; i < 317; ++i )//先按块检索
{
if( sum + block[i] >= K )
break;
sum += block[i];
}
for( j = i * 317; j < (i + 1) * 317; ++j )//确定目标块后,在目标块中检索
{
sum += cnt[j];
if( sum >= K )
{
printf("%d\n", j);
break;
}
}
代码:
#include <vector>
#include <stack>
#include <string>
#include <iostream>
using namespace std;
int main()
{
vector<int> block_cnt(316, 0), cnt(100001, 0);
int N;
stack<int> s;
string cmd;
cin >> N;
for( int i = 0, key; i < N; ++i )
{
cin >> cmd;
if( cmd == "Push" )
{
cin >> key;
s.push(key);
++cnt[key];
++block_cnt[key / 317];
}
else if( cmd == "Pop" && s.size() )
{
--cnt[s.top()];
--block_cnt[s.top() / 317];
cout << s.top() << endl;
s.pop();
}
else if( cmd == "PeekMedian" && s.size() )
{
for( int j = 0, sum = 0, flag = 0; j < 316 && !flag; sum += block_cnt[j++] )
if( sum + block_cnt[j] >= (s.size() + 1) / 2 )
for( int k = j * 317; (k < j * 317 + 317) && !flag; sum += cnt[k++] )
if( sum + cnt[k] >= (s.size() + 1) / 2 )
{
cout << k << endl;
flag = 1;
}
}
else cout << "Invalid" << endl;
}
}
解法二:树状数组BIT+二分查找
思路:
树状数组逻辑比较复杂,大家可以自行了解。
int bit[100001]为树状数组,bit[i]存储了栈中关键字 i - lowbit(i) + 1至i的数量,而如下函数getsum()即返回了了栈中小于等于关键字x的总关键字数;
int getsum( int x )
{
int sum = 0;
for( int i = x; i > 0; i -= lowbit(i) )
sum += bit[i];
return sum;
}
再利用折半查找,确定第K = ((N+1)/2)小的关键字
int K = ( s.size() + 1 ) / 2, l = 1, r = 100000, m;
while( l < r )
{
m = (l + r ) / 2;
if( getsum( m ) >= K )
r = m;
else l = m + 1;
}
printf("%d\n", l);
代码:
#include <iostream>
#include <vector>
#include <string>
using namespace std;
vector<int> s, bit(100001, 0);
int lowbit( int x )
{
return x&(-x);
}
void update( int x, int k )
{
for( int i = x; i < 100001; i += lowbit(i) )
bit[i] += k;
}
int getsum( int x )
{
int sum = 0;
for( int i = x; i > 0; i -= lowbit(i) )
sum += bit[i];
return sum;
}
int main()
{
int N;
scanf("%d", &N);
for( int i = 0, n; i < N; ++i )
{
string str;
cin >> str;
if( str == "Push" )
{
scanf("%d", &n);
s.push_back(n);
update( n, 1 );
}
else if( !s.size() )
printf("Invalid\n");
else if( str == "Pop" )
{
printf("%d\n", s[ s.size() - 1 ]);
update( s[ s.size() - 1 ], -1 );
s.pop_back();
}
else
{
int pos = ( s.size() + 1 ) / 2, l = 1, r = 100000, m;
while( l < r )
{
m = (l + r ) / 2;
if( getsum( m ) >= pos )
r = m;
else l = m + 1;
}
printf("%d\n", l);
}
}
}