Description
给出一棵 n n 个节点的树,第个节点点权为 ai a i ,有 q q 次查询,每次给出三个整数,假设 a a 到的树上简单路径编号为 t0,t1,....,tm t 0 , t 1 , . . . . , t m ,要查询 t0,tk,...,tpk(pk≤m) t 0 , t k , . . . , t p k ( p k ≤ m ) 这些点的点权异或和
Input
多组用例,每组用例输入两个整数 n,q n , q 表示点数和查询数,之后 n−1 n − 1 行每行输入两个整数 u,v u , v 表示一条树边,之后输入 n n 个整数表示第 i i 个点的点权,最后行每行输入三个整数 a,b,k a , b , k 表示一组查询 (∑n≤5⋅104,∑q≤5⋅105,1≤ai≤109,1≤a,b,k≤n) ( ∑ n ≤ 5 ⋅ 10 4 , ∑ q ≤ 5 ⋅ 10 5 , 1 ≤ a i ≤ 10 9 , 1 ≤ a , b , k ≤ n )
Output
对于每组用例,输出查询结果
Sample Input
5 6
1 5
4 1
2 1
3 2
19
26
0
8
17
5 5 1
1 3 2
3 2 1
5 4 2
3 4 4
1 4 5
Sample Output
17
19
26
25
0
19
Solution
在线倍增求 LCA L C A 和快速找到一个点的第 x x 级祖先,当较大时,被计算的点不是很多,直接从 a a 开始往上找其级祖先(不要越过两点的 LCA L C A ),然后从 b b 开始往上找即可,注意从开始时要去掉路径末端多余点的影响,当 k k 较小时,预处理从一个点到根节点每次跳个点的点权异或和,用预处理的结果快速得到答案即可,具体操作见代码
Code
v#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
const int maxn=50005;
int n,q,val[maxn],p[maxn][17],deep[maxn],f[maxn][81],dp[maxn][81];
vector<int>g[maxn];
void dfs(int u,int fa)
{
p[u][0]=fa;
for(int i=1;i<=16;i++)p[u][i]=p[p[u][i-1]][i-1];
f[u][1]=fa;
for(int i=2;i<=80;i++)f[u][i]=f[f[u][i-1]][1];
for(int i=1;i<=80;i++)dp[u][i]=dp[f[u][i]][i]^val[u];
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa)continue;
deep[v]=deep[u]+1;
dfs(v,u);
}
}
int lca(int a,int b)
{
int i,j;
if(deep[a]<deep[b])swap(a,b);
for(i=0;(1<<i)<=deep[a];i++);
i--;
for(j=i;j>=0;j--)
if(deep[a]-(1<<j)>=deep[b])
a=p[a][j];
if(a==b) return a;
for(j=i;j>=0;j--)
{
if(p[a][j]&&p[a][j]!=p[b][j])
{
a=p[a][j];
b=p[b][j];
}
}
return p[a][0];
}
int find(int x,int step)//找到x第step级祖先
{
for(int i=0;i<=16;i++)
if(step&(1<<i))
x=p[x][i];
return x;
}
int Solve1(int a,int b,int k)
{
int c=lca(a,b);
if(a==b)return val[a];
else
{
int num=deep[a]+deep[b]-2*deep[c];
b=find(b,num%k);
if(deep[b]<deep[c])
{
int x=(deep[a]-deep[c])/k+1;
b=find(a,x*k);
return dp[a][k]^dp[b][k];
}
if((deep[c]-deep[b])%k==0)
{
return dp[a][k]^dp[b][k]^val[c];
}
int x=(deep[a]-deep[c])/k+1;
int y=(deep[b]-deep[c])/k+1;
int aa=find(a,x*k),bb=find(b,y*k);
return dp[a][k]^dp[b][k]^dp[aa][k]^dp[bb][k];
}
}
int Solve2(int a,int b,int k)
{
int c=lca(a,b);
int ans=val[a];
int now=a;
while(deep[now]>=deep[c])
{
if(deep[now]-deep[c]<k)break;
now=find(now,k);
ans^=val[now];
}
if(deep[now]+deep[b]-2*deep[c]>=k)
{
now=find(b,(deep[now]+deep[b]-2*deep[c])%k);
ans^=val[now];
}
else return ans;
while(deep[now]>deep[c])
{
if(deep[now]-deep[c]<=k)break;
now=find(now,k);
if(now==c)break;
ans^=val[now];
}
return ans;
}
int main()
{
while(~scanf("%d%d",&n,&q))
{
for(int i=1;i<=n;i++)g[i].clear();
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v),g[v].push_back(u);
}
for(int i=1;i<=n;i++)scanf("%d",&val[i]);
memset(f,0,sizeof(f));
memset(p,0,sizeof(p));
deep[0]=0;
deep[1]=1;
dfs(1,0);
while(q--)
{
int a,b,k;
scanf("%d%d%d",&a,&b,&k);
if(k<=80)printf("%d\n",Solve1(a,b,k));
else printf("%d\n",Solve2(a,b,k));
}
}
return 0;
}