题意
n个数字,问m是多少个区间的中位数?偶数个数字的区间中位数认为是中间靠左的那个数字,比如(1,3,7,8)中位数为3。
E1:这n个数字是1~n的一个排列
E2:n个数字,数字大小不超过2e5
题解
做法①(可过E1,过不了E2):
思路请看:51Nod ~ 1682 ~ 中位数计数 (思维)
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 2e5+5;
const int FIX = 2e5;
typedef long long LL;
int n, m, a[MAXN], pos, b[MAXN*2];
int main()
{
scanf("%d%d", &n, &m);
for (int i = 0; i < n; i++)
{
scanf("%d", &a[i]);
if (a[i] == m) pos = i;
}
int cnt = 0;
LL ans = 0;
for (int i = pos; i >= 0; i--)
{
if (a[i] > m) cnt++;
if (a[i] < m) cnt--;
b[FIX+cnt]++;
}
cnt = 0;
for (int i = pos; i < n; i++)
{
if (a[i] > m) cnt++;
if (a[i] < m) cnt--;
ans += b[FIX-cnt];
ans += b[FIX-cnt+1];
}
printf("%lld\n", ans);
return 0;
}
/*
5 4
2 4 5 3 1
*/
做法②(E1,2均可过):
中位数小于等于m的区间 - 小于等于m-1的区间=中位数等于m的区间 ,现在我们需要搞一个计算中位数小于等于x的区间个数的函数。首先如果中位数小于等于x,那么<=x的数字一定大于>x的数字,我们记<=x的为+1,>x的数字记为-1,我们用变量s记录这个值,如果对于[L,R]区间(L<R),L位置时的s小于R位置时的s,证明[L,R]区间是一个符合要求的区间。
枚举右端点,维护s变量,对于右端点 R,怎么得到s_L <= s_R 的 L 的个数呢?
O(n)做法:
数组cnt[i]用于统计 s = i 的点的个数(由于s有负值所以我们以n作为0点即可),add维护s_L <= s_i 的 L 的个数
如果当前元素a[i]<=x,s要++,那么对于 i 位置符合要求的点比 i-1 多了cnt[++s]个
如果当前元素a[i]>x,s要--,那么对于 i 位置符合要求的点比 i-1 少了cnt[s--]个
统计答案,每次ans+=add即可,最终ans就是中位数大于等于x的区间数。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 2e5+5;
typedef long long LL;
int n, m, a[MAXN], cnt[MAXN*2];
LL slove(int x)
{
memset(cnt, 0, sizeof(cnt));
int s = n; cnt[n] = 1;
LL add = 0, ans = 0;
for (int i = 0; i < n; i++)
{
if (a[i] <= x) add += cnt[++s];//add += cnt[++s]同s++; add += cnt[s];
else add -= cnt[s--];//add -= cnt[s--]同add -= cnt[s]; s--;
cnt[s]++;
ans += add;
}
return ans;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 0; i < n; i++) scanf("%d", &a[i]);
printf("%lld\n", slove(m) - slove(m-1));
return 0;
}
/*
5 4
1 4 5 60 4
*/
O(n*logn)(BIT做法):
数组cnt[i]用于统计 s = i 的个数(由于s有负值所以我们以n作为0点即可)
BIT维护cnt[s]这个数组,枚举右端点,每次对于当前右端点,查询有cnt中有多少个小于等于s的值。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 2e5+5;
typedef long long LL;
int n, m, a[MAXN];
struct BIT
{
int n, c[MAXN*2];
void init (int n)
{
this->n = n;
memset(c, 0, sizeof(c));
}
void add(int p, int x)
{
for (int i = p; i <= n; i += i&-i)
c[i] += x;
}
int sum(int p)
{
int ans = 0;
for (int i = p; i >= 1; i -= i&-i)
ans += c[i];
return ans;
}
}bit;
LL slove(int x)
{
bit.init(2*n+1);
bit.add(n+1, 1);
int s = n+1;
LL ans = 0;
for (int i = 1; i <= n; i++)
{
if (a[i] <= x) s++;
else s--;
bit.add(s, 1);
ans += bit.sum(s);
}
return ans;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
printf("%lld\n", slove(m) - slove(m-1));
return 0;
}
/*
5 4
1 4 5 60 4
*/