题意:给你n个点,n-1条边,然后有m次查询,问u,v到两点距离相等的点有多少个。
分析:因为只有n-1条边,所以任意两点的路径只有唯一的一条,那么到两点距离相等的点,只有与这个路径的中点相连的那些点,如何求得中点呢?直接求的话肯定超时,我们可以把这个图转化成一个树,用一次dfs遍历建树,然后求得lca即可知道这条路径长度s,u,v两点中层数较深的那点往上走s/2步即为中点。
找到可中点如何求的结果呢?设u为层数较深的那个点,up(x,t)为点x向上走t步。首先路径长度为奇数的肯定无解,因为没有中点存在;u,v在同一层结果是n-son[up(u,s/2-1)] -son[up(v,s/2-1)];否则结果就是son[up(u,s/2)] - son[up(u,s/2-1)];最后有一个坑点那就是u == v,直接输出n。
开始不会写这题,后来参考了别人的代码,学会了用倍增法求lca。
看了这个帖才懂的倍增法:http://www.tuicool.com/articles/N7jQV32
AC代码:
#include <algorithm>
#include <iostream>
#include <cstdio>
#include <vector>
#include <stack>
#include <queue>
using namespace std;
const int maxn = 100007;
const int N = 17;
struct tree{
int p[maxn][N];
int d[maxn];
int son[maxn];
void dfs(int u);
void init();
int up(int x, int t);
int lca(int u, int v);
int result(int u, int v);//结果
};
vector<int> list[maxn];
tree tr;
int n,m;
void tree::dfs(int u){//dfs建树
for(int i = 0; i < list[u].size(); i++){
int v = list[u][i];
if(v != p[u][0]){
d[v] = d[u] + 1;
p[v][0] = u;
for(int j = 1; j < N; j++){
p[v][j] = p[p[v][j-1]][j-1];
}
dfs(v);
son[u] += son[v];
}
}
}
void tree::init(){
for(int i = 1; i <= n; i++){
son[i] = 1;
}
for(int j = 0; j < N; j++){
p[1][j] = 1;
}
d[1] = 1;
dfs(1);
}
int tree::up(int x, int t){//点x往上走t步
for(int i = 0; i < N; i++)
if(t&(1<<i)) x = p[x][i];
return x;
}
int tree::lca(int u, int v){//
if(d[u] > d[v]) u = up(u,d[u]-d[v]);
if(u == v) return u;
for(int i = N-1; i >= 0; i--)
if(p[u][i] != p[v][i]) u = p[u][i], v = p[v][i];
return p[u][0];
}
int tree::result(int u, int v){
if(u == v) return n;
else{
if(d[u] < d[v]) swap(u,v);
int x = lca(u,v);
int s = d[u]-d[x] + d[v]-d[x];//距离
if(s&1) return 0;//无解
if(d[u] == d[v]){
u = up(u,s/2-1);
v = up(v,s/2-1);
return n - son[u] - son[v];
}
else{
u = up(u,s/2-1);
return son[p[u][0]] - son[u];
}
}
}
void input(){
scanf("%d",&n);
int u,v;
for(int i = 0; i < n-1; i++){
scanf("%d%d",&u,&v);
list[u].push_back(v);
list[v].push_back(u);
}
}
void solve(){
tr.init();
scanf("%d",&m);
int u,v;
for(int i = 0; i < m; i++){
scanf("%d%d",&u,&v);
printf("%d\n",tr.result(u,v));
}
}
int main(){
input();
solve();
return 0;
}