一个数n个节点,m组询问,需回答与a,b距离相等的点的个数。
首先转化为根树,求得a,b的lca,d[i]表示i到节点1的距离,d1=d[a]-d[lca],d2=d[b]-d[lca],d1+d2为奇数时结果为0,偶数时两种情况:
num[i]为以i为根的子树节点数。
1.d1==d2,u=a向上走d1-1步,v=b向上走d2-1步,结果为n-num[u]-num[v];
2.假设d1>d2,u=a向上(d1+d2)/2步,v=a向上(d1+d2)/2-1步,结果为num[u]-num[v];
求a向上x步的节点暴力求会t掉,jump[i][j]表示i向上(1<<j)步的节点。把x看成二进制。
#include<iostream>
#include<string>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<iomanip>
#include<map>
#include<algorithm>
#include<queue>
#include<set>
#define inf 1000000000
#define pi acos(-1.0)
#define eps 1e-8
#define seed 131
using namespace std;
typedef pair<int,int> pii;
typedef unsigned long long ULL;
typedef long long LL;
const int maxn=100005;
int n,m;
int step;
int num[maxn];
vector<int>vec[maxn];
int p[maxn];
int d[maxn*2],deep[maxn*2],first[maxn*2];
int dp[maxn*2][30];
int dist[maxn];
int jump[maxn][36];
void dfs(int u,int fa);
void dfs2(int u,int dep);
int N(int u);
void ST(int n);
int query(int s,int t);
void Jump();
int JumpUp(int x, int y)
{
int ret = x;
for (int i = 0; i <= 18; ++i)
if (y & (1 << i)) ret = jump[ret][i];
return ret;
}
int main()
{
int a,b;
scanf("%d",&n);
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&a,&b);
vec[a].push_back(b);
vec[b].push_back(a);
}
p[1]=-1;
dist[1]=0;
dfs(1,-1);
for(int i=1;i<=n;i++)
vec[i].clear();
for(int i=2;i<=n;i++)
vec[p[i]].push_back(i);
step=1;
dfs2(1,0);
memset(num,-1,sizeof(num));
N(1);
ST(step-1);
Jump();
scanf("%d",&m);
for(int i=0;i<m;i++)
{
scanf("%d%d",&a,&b);
if(a==b)
{
printf("%d\n",n);
continue;
}
int l=first[a],r=first[b];
if(l>r)
swap(l,r);
int lca=d[query(l,r)];
int d1=dist[a]-dist[lca];
int d2=dist[b]-dist[lca];
if((d1+d2)%2==1)
printf("0\n");
else
{
if(d1==d2)
{
a=JumpUp(a,d1-1);
b=JumpUp(b,d1-1);
printf("%d\n",n-num[a]-num[b]);
}
else
{
if(d1<d2)
swap(a,b);
int v;
v=JumpUp(a,(d1+d2)/2-1);
a=JumpUp(a,(d1+d2)/2);
printf("%d\n",num[a]-num[v]);
}
}
}
return 0;
}
void Jump()
{
for (int i = 1; i <= n; ++i) jump[i][0] = p[i];
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= 18; ++j)
jump[i][j] = jump[jump[i][j - 1]][j - 1];
}
int query(int s,int t)
{
int k=(int)log2(t-s+1.0);
if(deep[dp[s][k]]<deep[dp[t-(1<<k)+1][k]])
return dp[s][k];
return dp[t-(1<<k)+1][k];
}
void ST(int n)
{
for(int i=1;i<=n;i++)
dp[i][0]=i;
for(int j=1;j<30;j++)
{
for(int i=1;i+(1<<j)-1<=n;i++)
{
if(deep[dp[i][j-1]]<deep[dp[i+(1<<j-1)][j-1]])
dp[i][j]=dp[i][j-1];
else
dp[i][j]=dp[i+(1<<j-1)][j-1];
}
}
}
int N(int u)
{
if(num[u]!=-1)
return num[u];
num[u]=1;
int len=vec[u].size();
for(int i=0;i<len;i++)
{
num[u]+=N(vec[u][i]);
}
return num[u];
}
void dfs2(int u,int dep)
{
d[step]=u;deep[step]=dep;first[u]=step++;
int len=vec[u].size();
for(int i=0;i<len;i++)
{
dfs2(vec[u][i],dep+1);
d[step]=u;
deep[step++]=dep;
}
}
void dfs(int u,int fa)
{
int len=vec[u].size();
for(int i=0;i<len;i++)
{
if(vec[u][i]!=fa)
{
dist[vec[u][i]]=dist[u]+1;
dfs(vec[u][i],p[vec[u][i]]=u);
}
}
}