将路径看作序列,求一个序列的最大异或和显然可以考虑求出它的异或线性基。
如果对于每个询问 (x,y) 的 lca 都求出它到子树的每个点的路径的异或线性基,显然时间和空间无法接受,考虑点分治。
对于每个重心
x
处理所有路径经过
通过统计每个子树的大小,将每个子树对应的询问放在一段连续的区间里面进行下一层分治。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=20005;
const int M=N*10;
int n,qn,num,mnsz,core,ed=1,g[N],sz[N],tag[N],cnt[N],st[N],en[N];
ll a[N],ans[M],f[N][61],tmp[61];
struct E{int v,nxt,cut;}e[N<<1];
struct Q{int x,y,k;}q[M],qtmp[M];
inline void adde(int x,int y){
e[++ed].v=y;e[ed].nxt=g[x];g[x]=ed;
}
void findroot(int x,int y){
int t=0;sz[x]=1;
for(int i=g[x];i;i=e[i].nxt)if(!e[i].cut&&e[i].v!=y){
findroot(e[i].v,x);
if(sz[e[i].v]>t)t=sz[e[i].v];
sz[x]+=sz[e[i].v];
}
if(num-sz[x]>t)t=num-sz[x];
if(t<mnsz)mnsz=t,core=x;
}
void dfs(int x,int y,int z){
int i,j;ll t=a[x];tag[x]=z;sz[x]=1;
for(i=60;~i;--i)f[x][i]=f[y][i];
for(i=60;~i;--i)if((t>>i)&1LL){
if(f[x][i])t^=f[x][i];
else {f[x][i]=t;break;}
}
for(int i=g[x];i;i=e[i].nxt)if(!e[i].cut&&e[i].v!=y){
dfs(e[i].v,x,z);sz[x]+=sz[e[i].v];
}
}
ll merge(int x,int y){
int i,j;ll t;
for(i=60;~i;--i)tmp[i]=f[x][i];
for(i=60;~i;--i)if(f[y][i]){
t=f[y][i];
for(j=i;~j;--j)if((t>>j)&1LL){
if(tmp[j])t^=tmp[j];
else {tmp[j]=t;break;}
}
}
ll res=0;
for(i=60;~i;--i){
if((res^tmp[i])>res)res^=tmp[i];
}
return res;
}
void solve(int x,int l,int r){
//if(l>r)return;
int i;
for(i=60;~i;--i)f[x][i]=0;
for(i=60;~i;--i)if((a[x]>>i)&1LL){f[x][i]=a[x];break;}
tag[x]=x;cnt[x]=0;
for(i=g[x];i;i=e[i].nxt)if(!e[i].cut){
cnt[e[i].v]=0;
dfs(e[i].v,x,e[i].v);
}
for(i=l;i<=r;++i){
if(tag[q[i].x]!=tag[q[i].y]||tag[q[i].x]==x)cnt[x]++;
else cnt[tag[q[i].x]]++;
}
st[x]=l;en[x]=l-1;
int t=l+cnt[x];
for(i=g[x];i;i=e[i].nxt)if(!e[i].cut){
st[e[i].v]=t;
en[e[i].v]=t-1;
t+=cnt[e[i].v];
}
for(i=l;i<=r;++i){
if(tag[q[i].x]!=tag[q[i].y]||tag[q[i].x]==x)qtmp[++en[x]]=q[i];
else qtmp[++en[tag[q[i].x]]]=q[i];
}
for(i=l;i<=r;++i)q[i]=qtmp[i];
for(i=st[x];i<=en[x];++i){
ans[q[i].k]=merge(q[i].x,q[i].y);
}
for(i=g[x];i;i=e[i].nxt)if(!e[i].cut){
if(st[e[i].v]>en[e[i].v])continue;
e[i].cut=e[i^1].cut=1;
num=mnsz=sz[e[i].v];
core=e[i].v;
findroot(e[i].v,0);
solve(core,st[e[i].v],en[e[i].v]);
}
}
int main(){
int i,j,x,y,z;
scanf("%d%d",&n,&qn);
for(i=1;i<=n;++i)scanf("%lld",&a[i]);
for(i=1;i<n;++i)scanf("%d%d",&x,&y),adde(x,y),adde(y,x);
for(i=1;i<=qn;++i)scanf("%d%d",&q[i].x,&q[i].y),q[i].k=i;
num=mnsz=n;core=1;
findroot(1,0);
solve(core,1,qn);
for(i=1;i<=qn;++i)printf("%lld\n",ans[i]);
return 0;
}