题目描述
企鹅国的城市结构是一棵树,有 N 座城市和 N-1 条无向道路,每条道路都一样长。豆豆和豆沙准备去参加 NOIP(National Olympiad in Informatics for Penguin),但是他们住在不同的地方,豆豆住在城市 A ,豆沙住在城市 B 。他们想找一个距离 A 和 B 一样远的集合地点,所以他们想知道有多少个城市满足这个要求?
由于他们会参加很多次 NOIP ,所以有很多个询问。
输入格式:
第一行一个整数 N,代表城市个数。
接下来 N-1 行,每行两个数字 F 和 T ,表示城市 F 和城市 T 之间有一条道路。
接下来一行一个整数 M 代表询问次数。
接下来 M 行,每行两个数字 A 和 B ,表示这次询问的城市 A 和城市 B(A可能与B相同)。
输出格式:
输出 M 行,每行一个整数表示到 A 和 B 一样远的城市个数。
样例输入1:
4
1 2
2 3
2 4
2
1 2
1 3
样例输出1:
0
2
样例输入2:
4
1 2
2 3
2 4
2
1 1
3 3
样例输出2:
4
4
数据范围:
对于 30% 的数据:N,M≤1000;
对于另外 10% 的数据:A=B;
对于另外 30% 的数据:保证树的形态随机;
对于 100% 的数据:1≤N,M≤100000。
题目分析
考试总结:考试时几乎写的是正解,就是最后找中点,脑子短路了,没有就像求lca一样倍增,结果却去dfs,写完才觉得时间复杂度不对。然后就TLE。
分析:我们考虑首先如果询问的两个点就是同一个点,那么答案就是n,这个从样例2就可以看出。对于两个点,要到这两个点的距离相等,只可能是这两个点的路径的中点以及这个中点所相连的所有子树和(除开路径方向的两个子树)。先预处理出每个节点size的大小(它自己再加上它所有子树的大小)。于是我们可以先用lca算出两点路径长度,然后选取深度较深的点让其往上跳len/2,就可以找到中点。统计答案分两种情况:一种是它们的中点就是它们的公共祖先(因为前面已经判断了两个点相同的情况,所以这里可以用深度相等来判断),此时的答案因为包括中点的父亲们,所以用总的n减去在路径上且比中点深度大1的两个点的size的大小;另一种是中点不是它们的公共祖先,那就直接求出在路径上且比中点深度大1的点,用中点的size减去这个点的size就是答案(因为路径的另一个来向就是中点的父亲,是都不计入答案的)。
附代码
#include<iostream>
#include<cstring>
#include<string>
#include<cstdlib>
#include<cstdio>
#include<ctime>
#include<cmath>
#include<cctype>
#include<iomanip>
#include<algorithm>
using namespace std;
const int N=1e5+100;
int tot,n,m,a,b,c,nxt[N*2],first[N],to[N*2],dep[N],fa[N][25];
int sze[N],cha,len,ans,e,f;
int readint()
{
char ch;int i=0,f=1;
for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
if(ch=='-') {ch=getchar();f=-1;}
for(;ch>='0'&&ch<='9';ch=getchar()) i=(i<<3)+(i<<1)+ch-'0';
return i*f;
}
void create(int x,int y)
{
tot++;
nxt[tot]=first[x];
first[x]=tot;
to[tot]=y;
}
void dfs(int u,int f)
{
fa[u][0]=f;
for(int i=1;i<=20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];
if(v!=f)
{
dep[v]=dep[u]+1;
dfs(v,u);
sze[u]+=sze[v];
}
}
}
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
int d=dep[x]-dep[y];
for(int i=20;i>=0;i--)
if((1<<i)&d) x=fa[x][i];
if(x==y) return x;
for(int i=20;i>=0;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int main()
{
//freopen("equal.in","r",stdin);
//freopen("equal.out","w",stdout);
int x,y;
n=readint();
for(int i=1;i<n;i++)
{
x=readint();y=readint();
create(x,y);
create(y,x);
}
for(int i=1;i<=n;i++) sze[i]=1;
dep[1]=0;
dfs(1,0);
m=readint();
while(m--)
{
a=readint();b=readint();
if(a==b)
{
printf("%d\n",n);
continue;
}
c=lca(a,b);
len=dep[a]+dep[b]-2*dep[c];
if(len%2==1)
{
printf("0\n");
continue;
}
if(dep[a]<dep[b]) swap(a,b);
int d=len/2;
if(dep[a]==dep[b])
{
e=a;f=b;
for(int i=20;i>=0;i--)//用倍增来找点
if((d-1)&(1<<i)) e=fa[e][i];
for(int i=20;i>=0;i--)
if((d-1)&(1<<i)) f=fa[f][i];
ans=n-sze[e]-sze[f];
}
else
{
e=a;f=a;
for(int i=20;i>=0;i--)
if(d&(1<<i)) e=fa[e][i];
for(int i=20;i>=0;i--)
if((d-1)&(1<<i)) f=fa[f][i];
ans=sze[e]-sze[f];
}
cout<<ans<<endl;//printf("%d\n",ans);
}
return 0;
}