Sum of xor sum ACM-ICPC 2017 Asia Xi’an(线段树维护子区间合并)
题意:给定n个数
a
1
,
a
2
,
…
,
a
n
a_1,a_2,\dots,a_n
a1,a2,…,an,给出q个询问求[L,R]内所有子区间的异或和。比如,数组1,2,3中[1,3]的异或和为:
1
+
1
x
o
r
2
+
1
x
o
r
2
x
o
r
3
+
2
+
2
x
o
r
3
+
3
1+1xor2+1xor2xor3+2+2xor3+3
1+1xor2+1xor2xor3+2+2xor3+3
思路:题目要我们求异或和,因此我们需要将每一个
a
i
a_i
ai拆分成二进制位。把每个数的每一位提出来后,我们需要统计出现奇数个1的区间。
- 对于区间的合并,我们需要统计的是:左边奇数个1的区间数+右边奇数个1区间+横跨两边的奇数个1的区间数。
- 而横跨两边的答案就是:左边以mid结尾的奇数个1的区间数 × \times × 右边以mid+1为开头的偶数个1的区间数+左边以mid结尾的偶数个1的区间数 × \times ×右边以mid+1为开头的奇数个1的区间数
#include <iostream>
#include <algorithm>
#include <cstdio>
#define ls (rt<<1)
#define rs ((rt<<1)|1)
#define ll long long
using namespace std;
const int maxn=1e5+5,INF=0x3f3f3f3f;
const int mod=1e9+7;
int t,n,q;
int a[maxn];
struct Node
{
ll ans[21],lx[21],rx[21];
int num[21];
int len;
}ST[maxn<<2];
Node Merge(Node a,Node b)
{
Node ret;
for(int i=1;i<=20;++i)
{
ret.ans[i]=a.ans[i]+b.ans[i];
ret.ans[i]+=a.rx[i]*(b.len-b.lx[i]);
ret.ans[i]+=(a.len-a.rx[i])*b.lx[i];
if(a.num[i]&1)
ret.lx[i]=a.lx[i]+(b.len-b.lx[i]);
else
ret.lx[i]=a.lx[i]+b.lx[i];
if(b.num[i]&1)
ret.rx[i]=b.rx[i]+(a.len-a.rx[i]);
else
ret.rx[i]=b.rx[i]+a.rx[i];
ret.num[i]=a.num[i]+b.num[i];
}
ret.len=a.len+b.len;
return ret;
}
void Build(int rt,int L,int R)
{
if(L==R)
{
for(int i=1;i<=20;++i)
{
if(a[L]&(1<<i-1))
ST[rt].ans[i]=ST[rt].lx[i]=ST[rt].rx[i]=ST[rt].num[i]=1;
else
ST[rt].ans[i]=ST[rt].lx[i]=ST[rt].rx[i]=ST[rt].num[i]=0;
}
ST[rt].len=1;
return;
}
int mid=(L+R)>>1;
Build(ls,L,mid);
Build(rs,mid+1,R);
ST[rt]=Merge(ST[ls],ST[rs]);
}
Node Query(int rt,int l,int r,int L,int R)
{
if(l==L&&R==r)
return ST[rt];
int mid=(L+R)>>1;
if(r<=mid)
return Query(ls,l,r,L,mid);
else if(l>mid)
return Query(rs,l,r,mid+1,R);
else
{
Node a=Query(ls,l,mid,L,mid);
Node b=Query(rs,mid+1,r,mid+1,R);
return Merge(a,b);
}
}
int main()
{
scanf("%d",&t);
while(t--)
{
scanf("%d%d",&n,&q);
for(int i=1;i<=n;++i)
scanf("%d",&a[i]);
Build(1,1,n);
while(q--)
{
int l,r;
scanf("%d%d",&l,&r);
Node tmp=Query(1,l,r,1,n);
ll ans=0;
for(int i=20;i>=1;--i)
ans=((ans<<1ll)%mod+tmp.ans[i])%mod;
printf("%lld\n",ans);
}
}
return 0;
}