LYK loves games
题解
还是挺简单的。
首先
O
(
n
q
l
o
g
n
)
O\left(nqlog\,n\right)
O(nqlogn)的做法应该是很好想到的,可以树dp,考虑对每个
v
a
l
i
val_{i}
vali进行二进制拆分,对于每一位单独进行处理,这样的话异或就比较好操作了。
我们每次只需要合并点
u
,
v
u,v
u,v时将两个当前位置不同的方案数乘起来再乘上当前位的大小即可。
设点
u
u
u第
i
i
i位为
0
,
1
0,1
0,1时的方案数为
s
u
m
u
,
i
,
0
/
1
sum_{u,i,0/1}
sumu,i,0/1,只需要让
a
n
s
+
=
2
i
s
u
m
u
,
i
,
0
/
1
s
u
m
u
,
i
,
1
/
0
ans+=2^isum_{u,i,0/1}sum_{u,i,1/0}
ans+=2isumu,i,0/1sumu,i,1/0即可。
但是这样做的话每次修改时都要再对一条链上的点进行合并,时间复杂度达到了
O
(
n
q
l
o
g
n
)
O\left(nqlog\,n\right)
O(nqlogn),不随机的最后20pts是过不了的。
考虑建点分树,减少树的深度。
我们把点分树建出来后,需要建一棵线段树来维护子树中的点到它这里路径当前位异或值为
0
/
1
0/1
0/1的方案数。
我们只需要维护修改,查询与
0
/
1
0/1
0/1翻转这三个操作,
时间复杂度就这样被降到
O
(
n
l
o
g
2
n
)
O\left(nlog^2n\right)
O(nlog2n)。
源码
#include<cstdio>
#include<cmath>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<vector>
#include<queue>
using namespace std;
#define MAXN 10005
const int INF=0x7f7f7f7f;
typedef long long LL;
typedef pair<int,int> pii;
template<typename _T>
void read(_T &x){
_T f=1;x=0;char s=getchar();
while('0'>s||s>'9'){if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+(s^48);s=getchar();}
x*=f;
}
int n,q,val[MAXN],head[MAXN],tot,dep[MAXN],dfn[20][MAXN];
int siz[20][MAXN],A[20][2],B[20][2],bel[20][MAXN],tim[MAXN];
int all[MAXN],num[MAXN][20],father[MAXN];LL ans;bool vis[MAXN];
struct edge{int to,nxt;}e[MAXN<<1];
void addEdge(int u,int v){e[++tot]=(edge){v,head[u]};head[u]=tot;}
pii operator + (const pii &a,const pii &b){return make_pair(a.first+b.first,a.second+b.second);}
pii operator - (const pii &a,const pii &b){return make_pair(a.first-b.first,a.second-b.second);}
struct segmentTree{
#define lson (rt<<1)
#define rson (rt<<1|1)
int sum[MAXN<<2][2];bool rev[MAXN<<2];
void updata(int rt){
sum[rt][0]=sum[lson][0]+sum[rson][0];
sum[rt][1]=sum[lson][1]+sum[rson][1];
}
void downdata(int rt){
if(!rev[rt])return ;rev[rt]=0;
swap(sum[lson][0],sum[lson][1]);rev[lson]^=1;
swap(sum[rson][0],sum[rson][1]);rev[rson]^=1;
}
void insert(int rt,int l,int r,int ai,int aw){
if(l>r||l>ai||r<ai)return ;int mid=l+r>>1;
if(l==r){sum[rt][aw]++;return ;}
if(ai<=mid)insert(lson,l,mid,ai,aw);
else insert(rson,mid+1,r,ai,aw);
updata(rt);
}
void Reverse(int rt,int l,int r,int al,int ar){
if(l>r||al>r||ar<l)return ;
if(al<=l&&r<=ar){swap(sum[rt][0],sum[rt][1]);rev[rt]^=1;return ;}
int mid=l+r>>1;downdata(rt);
if(al<=mid)Reverse(lson,l,mid,al,ar);
if(ar>mid)Reverse(rson,mid+1,r,al,ar);
updata(rt);
}
pii query(int rt,int l,int r,int al,int ar){
if(l>r||al>r||ar<l)return make_pair(0,0);int mid=l+r>>1;
if(al<=l&&r<=ar)return make_pair(sum[rt][0],sum[rt][1]);
pii res=make_pair(0,0);downdata(rt);
if(al<=mid)res=res+query(lson,l,mid,al,ar);
if(ar>mid)res=res+query(rson,mid+1,r,al,ar);
return res;
}
}T[20][20];
void dosaka(int u,int fa,int p,int x){
int d=dep[p];bel[d][u]=(fa==p)?u:bel[d][fa];
dfn[d][u]=++tim[d];siz[d][u]=1;x^=val[u];
for(int i=0;i<15;i++)B[i][x>>i&1]++,T[d][i].insert(1,1,n,tim[d],x>>i&1);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(vis[v]||v==fa)continue;
dosaka(v,u,p,x);siz[d][u]+=siz[d][v];
}
}
int getRoot(int u,int fa,int Sz,int &g){
int sz=1,tmp;bool flag=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(vis[v]||v==fa)continue;
sz+=(tmp=getRoot(v,u,Sz,g));flag&=tmp<<1<=Sz;
}
if(flag&&(Sz-sz)<<1<=Sz)g=u;return sz;
}
int sakura(int u,int Sz,int dp){
getRoot(u,0,Sz,u);vis[u]=1;dep[u]=dp;dfn[dp][u]=++tim[dp];siz[dp][u]=1;memset(B,0,sizeof B);
//printf("sakura %d:%d %d %d\n",u,dep[u],dfn[dp][u],siz[dp][u]);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(vis[v])continue;memcpy(A,B,sizeof B);dosaka(v,u,u,0);
all[u]+=siz[dp][u]*siz[dp][v];siz[dp][u]+=siz[dp][v];
for(int i=0;i<15;i++){
int x=val[u]>>i&1,tmp=(B[i][0]-A[i][0])*A[i][!x]+(B[i][1]-A[i][1])*A[i][x]+B[i][!x]-A[i][!x];
num[u][i]+=tmp;ans+=(1LL<<i)*tmp;
}
}
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;if(vis[v])continue;
father[sakura(v,siz[dp][v],dp+1)]=u;
}
return u;
}
void solve(int x,int now){
int pre=val[x];val[x]=now;
for(int i=0;i<15;i++){
if((now>>i&1)==(pre>>i&1))continue;
ans-=num[x][i]*(1LL<<i);
num[x][i]=all[x]-num[x][i];
ans+=num[x][i]*(1LL<<i);
}
for(int u=father[x];u;u=father[u])
for(int i=0,d=dep[u];i<15;i++){
if((now>>i&1)==(pre>>i&1))continue;bool t=val[u]>>i&1;pii sx,su;
sx=T[d][i].query(1,1,n,dfn[d][x],dfn[d][x]+siz[d][x]-1);
su=T[d][i].query(1,1,n,dfn[d][u]+1,dfn[d][u]+siz[d][u]-1)-T[d][i].query(1,1,n,dfn[d][bel[d][x]],dfn[d][bel[d][x]]+siz[d][bel[d][x]]-1);
T[d][i].Reverse(1,1,n,dfn[d][x],dfn[d][x]+siz[d][x]-1);
//printf("solve %d %d %d %d %d:%d %d %d %d\n",u,i,dfn[d][u],siz[d][u],bel[d][x],sx.first,sx.second,su.first,su.second);
int dt=su.first*(!t?(sx.first-sx.second):(sx.second-sx.first))+su.second*(!t?(sx.second-sx.first):(sx.first-sx.second))+(!t?(sx.first-sx.second):(sx.second-sx.first));
num[u][i]+=dt;ans+=dt*(1ll<<i);
}
}
int main(){
//freopen("games.in","r",stdin);
//freopen("games.out","w",stdout);
read(n);read(q);for(int i=1;i<=n;i++)read(val[i]);
for(int i=1;i<n;i++){
int u,v;read(u);read(v);
addEdge(u,v);addEdge(v,u);
}
sakura(1,n,0);
while(q--){
int l,r;read(l);read(r);
solve(l,r);printf("%lld\n",ans<<1);
}
return 0;
}