题面:
题意:
给定一棵树。
有m个询问,每个询问给出 a,b 两点,问树上到 a,b 距离相同的点有多少个。
树上两点之间的距离为两点之间的最短距离。
题解:
① a==b 那么 n个点与a,b的距离都相同 ans = n
②若dis(a,b)为奇数,那么没有任何点与a,b的距离相同 ans = 0
③设 lc = lca(a,b),如果dis( lc,a )= dis( lc ,b ),那么除了 a,b所在的 lc 的子树,其他的节点都可以。我们设 a 所在的 lc 的子树的根节点为 aa ( aa 是 lc 的儿子),b 所在的 lc 的子树的根节点为 bb (bb 是 lc 的儿子),那么 ans = n - si [ aa ] - si [ bb ]
④我们设 d [ a ] > d [ b ] ,我们设 x 为中间那个点(dis(a,b)/2),那么在 x 为根子树中,除了 a 所在的子树,其余的节点均可以。我们假设 xx 为 a–>b路径上的第 dis(a,b)/ 2 - 1 的节点,即 xx 是 x 的 a 方向的儿子,那么答案为 ans = si [ x ] - si [ xx ]
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<set>
#define ll long long
#define llu unsigned ll
#define ld long double
#define ui unsigned int
#define pr make_pair
#define pb push_back
#define ui unsigned int
//#define lc (cnt<<1)
//#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define forhead(x) for(int i=head[(x)];i;i=nt[i])
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=1e9+7;
const double eps=1e-8;
const double pi=acos(-1.0);
const int maxm=100100;
const int up=100000;
const int hashp=13331;
const int maxn=100100;
int head[maxn],ver[maxn<<1],nt[maxn<<1],tot=1;
int d[maxn],f[maxn][20],si[maxn],t;
int n,m;
void add(int x,int y)
{
ver[++tot]=y,nt[tot]=head[x],head[x]=tot;
}
void dfs(int x)
{
si[x]=1;
forhead(x)
{
int y=ver[i];
if(d[y]) continue;
d[y]=d[x]+1;
f[y][0]=x;
for(int j=1;j<=t;j++)
f[y][j]=f[f[y][j-1]][j-1];
dfs(y);
si[x]+=si[y];
}
}
int lca(int x,int y)
{
if(d[x]>d[y]) swap(x,y);
for(int i=t;i>=0;i--)
if(d[f[y][i]]>=d[x]) y=f[y][i];
if(x==y) return x;
for(int i=t;i>=0;i--)
if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
int fi(int x,int de)
{
for(int i=t;i>=0;i--)
{
if(d[f[x][i]]>=de)
x=f[x][i];
}
return x;
}
int ask(int x,int y)
{
if(x==y) return n;
int lc=lca(x,y);
int dis=d[x]+d[y]-2*d[lc];
if(dis%2) return 0;
if(d[x]==d[y])
{
int xx=fi(x,d[lc]+1);
int yy=fi(y,d[lc]+1);
return n-si[xx]-si[yy];
}
if(d[x]<d[y]) swap(x,y);
int xx=fi(x,d[x]-dis/2+1);
return si[f[xx][0]]-si[xx];
}
int main(void)
{
int x,y;
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
t=log2(n)+1;
d[1]=1;
dfs(1);
scanf("%d",&m);
for(int i=1;i<=m;i++)
{
scanf("%d%d",&x,&y);
printf("%d\n",ask(x,y));
}
return 0;
}