题意
有一个长度为n的数组a和m个操作,每个操作形如
1 x在序列的末尾添加一个数x
2 l r询问[l,r]的数的和
3 x把序列中所有数都异或上x
4把序列中所有数从小到大排序
n,m≤105,x,ai≤109
n
,
m
≤
10
5
,
x
,
a
i
≤
10
9
分析
我们可以对那些排好序的数维护一棵Trie,每个节点维护该点子树中每一位的1的个数;对新加进来的数维护一下前缀和。
显然Trie是资辞打异或标记的,然后前缀和也可以整体打标记。
然后每次排序就把所有数暴力扔进Trie里面就好了。
时间和空间复杂度都是
O(nlog2n)
O
(
n
l
o
g
2
n
)
。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
const int N=100005;
int n,m,a[N],pre[N*2][35],rt,bin[35],tot_tag,swap_tag,sz,tot;
struct tree{int l,r,bit[35],sz;LL w;}t[N*65];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void ins(int &d,int w,int x)
{
if (!d) d=++sz;
t[d].sz++;t[d].w=x;
for (int i=0;i<=30;i++) t[d].bit[i]+=(x&bin[i])>0;
if (w<0) return;
if (x&bin[w]) ins(t[d].r,w-1,x);
else ins(t[d].l,w-1,x);
}
LL query(int d,int w,int l,int r)
{
if (l==1&&r==t[d].sz)
{
LL ans=0;
for (int i=0;i<=30;i++)
if (tot_tag&bin[i]) ans+=(LL)bin[i]*(t[d].sz-t[d].bit[i]);
else ans+=(LL)bin[i]*t[d].bit[i];
return ans;
}
if (w<0) return (LL)(t[d].w^tot_tag)*(r-l+1);
int v=(swap_tag&bin[w])?t[t[d].r].sz:t[t[d].l].sz;
if (r<=v) return query((swap_tag&bin[w])?t[d].r:t[d].l,w-1,l,r);
else if (l>v) return query((swap_tag&bin[w])?t[d].l:t[d].r,w-1,l-v,r-v);
else return query((swap_tag&bin[w])?t[d].r:t[d].l,w-1,l,std::min(r,v))+query((swap_tag&bin[w])?t[d].l:t[d].r,w-1,std::max(l-v,1),r-v);
}
void extend(int x)
{
a[++tot]=x^tot_tag;
for (int i=0;i<=30;i++)
pre[tot][i]=pre[tot-1][i]+((a[tot]&bin[i])>0);
}
void build()
{
swap_tag=tot_tag;
while (tot) ins(rt,30,a[tot]),tot--;
}
LL solve(int l,int r)
{
LL ans=0;
if (l<=t[rt].sz) ans+=query(rt,30,l,std::min(r,t[rt].sz));
if (r>t[rt].sz)
{
l=std::max(1,l-t[rt].sz);r-=t[rt].sz;
for (int i=0;i<=30;i++)
if (tot_tag&bin[i]) ans+=(LL)bin[i]*(r-l+1-pre[r][i]+pre[l-1][i]);
else ans+=(LL)bin[i]*(pre[r][i]-pre[l-1][i]);
}
return ans;
}
int main()
{
bin[0]=1;
for (int i=1;i<=30;i++) bin[i]=bin[i-1]*2;
n=read();
for (int i=1,x;i<=n;i++) x=read(),extend(x);
m=read();
while (m--)
{
int op=read();
if (op==1)
{
int x=read();
extend(x);
}
else if (op==2)
{
int l=read(),r=read();
printf("%lld\n",solve(l,r));
}
else if (op==3)
{
int x=read();
tot_tag^=x;
}
else build();
}
return 0;
}