题意:
有n个数,m次询问,每次询问 L R k,求在区间 [L,R] 中小于等于 k 的数有多少个。
思路:
用主席树来维护,每次只需要找到序列 b 中第一个等于 k 的数,那么要求的数必定在 b[1]~b[upper_bound(k)] 这个范围内,接下来就像线段树统计区间个数那样,若完全包含则直接加上 t[r].sum - t[l].sum 。否则就分两边递归统计。
一开始用求区间第 k 小模板写的,改了半天才知道是要 区间统计。。。
主席树
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<sstream>
#include<cstring>
#include<bitset>
#include<cstdio>
#include<string>
#include<deque>
#include<stack>
#include<cmath>
#include<queue>
#include<set>
#include<map>
#define mod 1000000007
using namespace std;
typedef long long ll;
const int maxn = 1e5+10;
struct node
{
int l,r;
int sum;
}t[maxn*40];
int n,m,a[maxn],b[maxn],rt[maxn],tot = 0;
void update(int x,int l,int r,int &p)
{
t[++tot] = t[p];
p = tot;
t[p].sum++;
if(l==r)
return ;
int mid = (l+r)>>1;
if(x<=mid)
update(x,l,mid,t[p].l);
else
update(x,mid+1,r,t[p].r);
}
int query(int st,int ed,int L,int R,int l,int r)
{
if(L<=l && r<=R)
return t[ed].sum-t[st].sum;
int mid = (l+r)>>1;
int ans = 0;
if(L<=mid)
ans += query(t[st].l,t[ed].l,L,R,l,mid);
if(R>mid)
ans += query(t[st].r,t[ed].r,L,R,mid+1,r);
return ans;
}
int main()
{
int T;
int cnt;
int l,r,x;
scanf("%d",&T);
for(int ca=1;ca<=T;ca++)
{
memset(rt,0,sizeof(rt));
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
b[i] = a[i];
}
sort(b+1,b+n+1);
cnt = unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;i++)
a[i] = lower_bound(b+1,b+cnt+1,a[i])-b;
for(int i=1;i<=n;i++)
{
rt[i] = rt[i-1];
update(a[i],1,cnt,rt[i]);
}
printf("Case %d:\n",ca);
while(m--)
{
scanf("%d%d%d",&l,&r,&x);
l++; r++;
int p = upper_bound(b+1,b+cnt+1,x)-b-1;
if(p)
printf("%d\n",query(rt[l-1],rt[r],1,p,1,cnt));
else
printf("0\n");
}
}
return 0;
}
树状数组
以数组的下标建立树状数组,把询问 和 每个点值 从小到大排序,在每一次查询前把比查询值小的值的下标加入树状数组,在查询一下在范围内有多少个
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<sstream>
#include<cstring>
#include<bitset>
#include<cstdio>
#include<string>
#include<deque>
#include<stack>
#include<cmath>
#include<queue>
#include<set>
#include<map>
#define mod 1000000007
using namespace std;
typedef long long ll;
const int maxn = 1e5+10;
struct node
{
int val;
int id;
}t[maxn];
struct query
{
int l,r;
int k;
int id;
}q[maxn];
int n,m,c[maxn],res[maxn];
bool cmp(node a,node b)
{
if(a.val==b.val)
return a.id<b.id;
return a.val<b.val;
}
bool cmp1(query a,query b)
{
if(a.k==b.k)
return a.id<b.id;
return a.k<b.k;
}
int lowbit(int x)
{
return x&(-x);
}
void add(int x,int val)
{
while(x<=n)
{
c[x] += val;
x += lowbit(x);
}
}
int ask(int x)
{
int ans = 0;
while(x)
{
ans += c[x];
x -= lowbit(x);
}
return ans;
}
int main()
{
int T;
scanf("%d",&T);
for(int ca=1;ca<=T;ca++)
{
memset(c,0,sizeof(c));
scanf("%d%d",&n,&m);
for(int i=0;i<n;i++)
{
scanf("%d",&t[i].val);
t[i].id = i+1;
}
for(int i=0;i<m;i++)
{
scanf("%d%d%d",&q[i].l,&q[i].r,&q[i].k);
q[i].l++;
q[i].r++;
q[i].id = i+1;
}
sort(t,t+n,cmp);
sort(q,q+m,cmp1);
for(int i=0,j=0;i<m;i++)
{
while(j<n && t[j].val<=q[i].k)
{
add(t[j].id,1);
j++;
}
res[q[i].id] = ask(q[i].r)-ask(q[i].l-1);
}
printf("Case %d:\n",ca);
for(int i=1;i<=m;i++)
printf("%d\n",res[i]);
}
return 0;
}