// 考虑倍增，然后每次查询就是将四个线性基合并。
#include<iostream>
#include<cstdio>
#include<cmath>
#include<vector>
#include<cstring>
using namespace std;
// Capacity of the per-node arrays; nodes are numbered from 1.
const int N=20005;
// n: number of nodes, Q: number of queries still to answer,
// cnt: number of adjacency-list entries handed out so far.
int n,Q,cnt;
// deep[x]: depth of x (main sets the root's depth to 1).
// head/next/list: linked adjacency list. head[x] is the first entry of x,
// next[e] the entry after e (0 ends the chain), list[e] the neighbour stored in e.
// Sized 2N because every undirected edge is inserted in both directions.
int deep[N],head[N],next[N<<1],list[N<<1];
// Scratch array (used 1-based) for staging values while two bases are merged.
long long q[N<<1];
// fa[x][i]: the 2^i-th ancestor of x; stays 0 where dfs never fills it in.
int fa[N][16];
// NOTE(review): not referenced anywhere in the visible code.
int ins[61];
// v[x][i]: linear basis of the values on the 2^i nodes of the upward path
// that starts at x (x itself included).
vector<long long> v[N][16];
/**
 * Reads the next decimal integer from stdin.
 *
 * Every character before the first digit is skipped; a '-' seen while
 * skipping makes the result negative. The single character that ends the
 * digit run is consumed as well.
 *
 * @return The parsed value, or 0 when the input ends before any digit.
 */
inline long long read()
{
    long long value=0,sign=1;
    // getchar() must be kept in an int: EOF lies outside the digit range, so
    // without an explicit test the skip loop below would never terminate.
    int c=getchar();
    while (c!=EOF&&(c<'0'||c>'9'))
    {
        if (c=='-') sign=-1;
        c=getchar();
    }
    while (c>='0'&&c<='9')
    {
        value=value*10+(c-'0');
        c=getchar();
    }
    return value*sign;
}
inline void insert(int x,int y)
{
next[++cnt]=head[x];
head[x]=cnt;
list[cnt]=y;
}
/**
 * Merges two sets of XOR vectors into one linear basis.
 *
 * The values of `a` followed by those of `b` are reduced by Gaussian
 * elimination over GF(2), scanning bits 60 down to 0. Each kept element owns
 * a pivot bit that no other kept element has set, and the result is ordered
 * by descending pivot bit.
 *
 * The work is done inside the by-value copy of `a`, so no fixed-size global
 * buffer is involved and inputs of any length are safe.
 *
 * @param a First set of values; used as the working buffer.
 * @param b Second set of values; only read.
 * @return  The reduced basis (empty when every input value is 0).
 */
std::vector<long long> Union(std::vector<long long> a,const std::vector<long long>& b)
{
    a.insert(a.end(),b.begin(),b.end());
    std::size_t rank=0; // a[0..rank) already holds finished basis rows
    for (int bit=60;bit>=0&&rank<a.size();--bit)
    {
        // Look for a not-yet-used row that has this bit set.
        std::size_t pivot=rank;
        while (pivot<a.size()&&!((a[pivot]>>bit)&1)) ++pivot;
        if (pivot==a.size()) continue;
        std::swap(a[pivot],a[rank]);
        // Clear the bit from every other row, finished rows included, so the
        // basis stays fully reduced.
        for (std::size_t i=0;i<a.size();++i)
            if (i!=rank&&((a[i]>>bit)&1)) a[i]^=a[rank];
        ++rank;
    }
    a.resize(rank); // rows past `rank` have been reduced to 0
    return a;
}
inline void dfs(int x)
{
for (int i=1;(1<<i)<=deep[x];i++)
{
fa[x][i]=fa[fa[x][i-1]][i-1];
v[x][i]=Union(v[x][i-1],v[fa[x][i-1]][i-1]);
}
for (int i=head[x];i;i=next[i])
if (list[i]!=fa[x][0])
{
fa[list[i]][0]=x;
deep[list[i]]=deep[x]+1;
dfs(list[i]);
}
}
/**
 * Lowest common ancestor of x and y, found with the fa[][] lifting table.
 *
 * @param x First node.
 * @param y Second node.
 * @return  The deepest node that is an ancestor (or self) of both.
 */
inline int lca(int x,int y)
{
    // Make x the deeper of the two.
    if (deep[y]>deep[x])
    {
        const int tmp=x;
        x=y;
        y=tmp;
    }
    // Raise x to y's depth, one set bit of the gap at a time.
    const int gap=deep[x]-deep[y];
    for (int bit=0;(1<<bit)<=gap;++bit)
        if (gap&(1<<bit)) x=fa[x][bit];
    if (x==y) return x;
    // Lift both nodes in lock-step while their ancestors still differ.
    for (int bit=15;bit>=0;--bit)
    {
        if (fa[x][bit]==fa[y][bit]) continue;
        x=fa[x][bit];
        y=fa[y][bit];
    }
    return fa[x][0];
}
/**
 * Follows fa[][] upward from x, taking one jump per set bit of t
 * (bits 15 down to 0, highest first).
 *
 * @param x Starting node.
 * @param t Number of levels to climb.
 * @return  The node reached after the jumps.
 */
inline int UP(int x,int t)
{
    int node=x;
    int level=15;
    while (level>=0)
    {
        if (t&(1<<level)) node=fa[node][level];
        --level;
    }
    return node;
}
/**
 * Linear basis of the node values on the path from x up to f, both ends
 * included.
 *
 * The path is covered by two, possibly overlapping, blocks of 2^k nodes: one
 * starting at x and one ending at f. Overlap is harmless for a basis.
 *
 * @param x Lower end of the path.
 * @param f Upper end of the path; the code assumes deep[f] <= deep[x] and
 *          that f lies on x's path to the root.
 * @return  Merged basis of the two covering blocks.
 */
inline vector<long long> query(int x,int f)
{
    const int len=deep[x]-deep[f]+1; // number of nodes on the path
    // Largest k with 2^k <= len, computed with integers only: truncating a
    // floating-point log2() could come out one too small at a power of two.
    int k=0;
    while ((2<<k)<=len) k++;
    return Union(v[x][k],v[UP(x,len-(1<<k))][k]);
}
int main()
{
n=read(); Q=read(); deep[1]=1;
for (int i=1;i<=n;i++) v[i][0].push_back(read());
for (int i=1;i<n;i++)
{
int u=read(),v=read();
insert(u,v); insert(v,u);
}
dfs(1);
while (Q--)
{
int u=read(),v=read(),t=lca(u,v);
vector<long long> ans=Union(query(u,t),query(v,t));
long long ANS=0;
for (int i=0;i<ans.size();i++)
if ((ANS^ans[i])>ANS) ANS^=ans[i];
printf("%lld\n",ANS);
}
return 0;
}