树状数组+trie树。
题目有两个要求:
1.求某一区间的种类数
2.在此基础上要满足一定条件
如果把这题拆开,这两个问题都是好做的。对于问题一,可以用树状数组离线,在
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)的时间内得到答案。对于问题二,判断
a
⨁
b
≤
c
a\bigoplus b\leq c
a⨁b≤c,可以用01trie树逐位判断。那么对于二者的结合,是否由一种数据结构可以既可以得到特定区间的种类,又可以判断异或关系式呢?
答案是肯定的,可以用树状数组+trie树。
具体说就是在树状数组每个节点上建立01trie树。就以样例为例,建成的数据结构如下图所示:
仅仅建立了trie树显然是不够的,因为trie树只能让我们知道这个树状数组结点管辖的区间内二进制数位的集合。那么如何从trie树上得到数字个数的信息?只需要记录每加入一个数字,经过结点的次数。下面给出一个例子:
然后就是怎么判断
a
⨁
b
≤
c
a\bigoplus b\leq c
a⨁b≤c,老套路,在字典树上与c逐位比较。
下面说说细节
1.关于如何在树状数组每个节点上建立trie树。
一开始我是定义了一个结构体数组,如下如所示:
struct t
{
int trie[maxn][2];
int size[maxn];
int cnt;
}tr[1000];
虽然这样容易理解,但是会mle。优化的方法类似于链式前向星,要把
n
n
n个trie树合在一个trie树上。具体就是用rt[]代表树的根节点在的下标。
2.关于数组大小
树状数组有
n
n
n个节点,平均每个结点管辖
l
o
g
n
logn
logn个节点,那么共有
32
n
l
o
g
n
32nlogn
32nlogn个节点,那么trie数组第一维要开1e5*400。
完整代码:
#include<bits/stdc++.h>
#define FAST ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define INF 0x3f3f3f3f
typedef long long ll;
const int maxn = 1e5+5;
using namespace std;
int cnt;
int trie[maxn*400][2];
int rt[maxn],size[maxn*400],ans[maxn];
struct que
{
int l,r;
int a,b;
int id;
}q[maxn];
bool cmp(que x, que y) { return x.r<y.r; }
int n,m;
int a[maxn],have[maxn];
void insert(int &pos, int num, int val)
{
if (!pos) pos=++cnt;
int now=pos;
for (int j=16; j>=0; j--)
{
int bit=(num>>j)&1;
if (!trie[now][bit]) trie[now][bit]=++cnt;
now=trie[now][bit];
size[now]+=val;
}
}
ll query(int pos, int a, int b)
{
if (!pos) return 0;
int now=pos;
ll ans=0;
for (int j=16; j>=0; j--)
{
int bt1=(b>>j)&1;
int bt2=(a>>j)&1;
if (bt1==0) now=trie[now][bt2];
else if (bt1==1)
{
ans+=size[trie[now][bt2]];
now=trie[now][bt2^1];
}
if (now==0) break;
}
return ans+size[now];
}
ll lowbit(ll x) { return (x&(-x)); }
void add(int pos, int num, int val)
{
int x=pos;
while(x<=n)
{
insert(rt[x],num,val);
x+=lowbit(x);
}
}
ll ask(int x, int a, int b)
{
ll ans=0;
while(x)
{
ans+=query(rt[x],a,b);
x-=lowbit(x);
}
return ans;
}
int main()
{
FAST;
cin>>n;
for (int i=1; i<=n; i++) cin>>a[i];
cin>>m;
for (int i=1; i<=m; i++)
{
cin>>q[i].l>>q[i].r>>q[i].a>>q[i].b;
q[i].id=i;
}
sort(q+1,q+m+1,cmp);
int last=0;
for (int i=1; i<=m; i++)
{
for (int j=last+1; j<=q[i].r; j++)
{
if (have[a[j]]) add(have[a[j]],a[j],-1);
add(j,a[j],1);
have[a[j]]=j;
}
last=q[i].r;
ans[q[i].id]=ask(q[i].r,q[i].a,q[i].b)-ask(q[i].l-1,q[i].a,q[i].b);
}
for (int i=1; i<=m; i++) cout<<ans[i]<<endl;
return 0;
}