题目描述
如题,给定一棵有根多叉树,请求出指定两个点直接最近的公共祖先。
输入输出格式
输入格式:
第一行包含三个正整数N、M、S,分别表示树的结点个数、询问的个数和树根结点的序号。
接下来N-1行每行包含两个正整数x、y,表示x结点和y结点之间有一条直接连接的边(数据保证可以构成树)。
接下来M行每行包含两个正整数a、b,表示询问a结点和b结点的最近公共祖先。
输出格式:
输出包含M行,每行包含一个正整数,依次为每一个询问的结果。
输入输出样例
输入样例#1: 复制
5 5 4 3 1 2 4 5 1 1 4 2 4 3 2 3 5 1 2 4 5
输出样例#1: 复制
4 4 1 4 4
说明
时空限制:1000ms,128M
数据规模:
对于30%的数据:N<=10,M<=10
对于70%的数据:N<=10000,M<=10000
对于100%的数据:N<=500000,M<=500000
样例说明:
该树结构如下:
第一次询问:2、4的最近公共祖先,故为4。
第二次询问:3、2的最近公共祖先,故为4。
第三次询问:3、5的最近公共祖先,故为1。
第四次询问:1、2的最近公共祖先,故为4。
第五次询问:4、5的最近公共祖先,故为4。
故输出依次为4、4、1、4、4。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#include<map>
#include<vector>
#include<set>
#include<map>
#include<queue>
#include<deque>
#include<cstdlib>
#define NN 1000000
#define M 45
using namespace std;
int grand[NN][M];
int gw[NN][M];
int depth[NN];
int list[NN];
int n,m,k,root;
int x,y;
int sum;
int N;
typedef struct
{
int nxt;
int to;
int length;
}edge;
edge p[NN];
inline void add(int x,int y,int z)
{
p[++sum].nxt=list[x];
p[sum].to=y;
p[sum].length=z;
list[x]=sum;
p[++sum].nxt=list[y];
p[sum].to=x;
p[sum].length=z;
list[y]=sum;
}
inline void dfs(int x)
{
for(int i=1;i<=N;i++)
{
grand[x][i]=grand[grand[x][i-1]][i-1];
gw[x][i]=gw[x][i-1]+gw[grand[x][i-1]][i-1];
}
for(int i=list[x];i!=0;i=p[i].nxt)
{
int to=p[i].to;
if(to!=grand[x][0])
{
depth[to]=depth[x]+1;
grand[to][0]=x;
gw[to][0]=p[x].length;
dfs(to);
}
}
}
inline void init()
{
//N=floor(log(n + 0.0) / log(2.0));//最多能跳的2^i祖先
N=20;
depth[root]=0;
memset(grand,0,sizeof(grand));
memset(gw,0,sizeof(gw));
dfs(root);
}
inline int lca(int a,int b)
{
if(depth[a]>depth[b])swap(a,b);
int ans=0;
for(int i=N;i>=0;i--)
{
if(depth[a]<depth[b]&&depth[grand[b][i]]>=depth[a]&&grand[b][i]!=0)
{
ans+=gw[b][i];
b=grand[b][i];
}
}
for(int i=N;i>=0;i--)
{
if(grand[a][i]!=grand[b][i])
{
ans+=gw[a][i];
ans+=gw[b][i];
a=grand[a][i];
b=grand[b][i];
}
}
if(a!=b)
{
ans+=gw[a][0],ans+=gw[b][0];
}
if(a==b)return a;
return grand[a][0];
}
int main()
{
scanf("%d%d%d",&n,&m,&k);
root=k;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y,1);
}
init();
for(int i=1;i<=m;i++)
{
scanf("%d%d",&x,&y);
printf("%d\n",lca(x,y));
}
}