题目链接
题意:求树中两个节点的最近公共祖先。
一、
先深搜记录每个结点的父亲节点。
先将两个节点中的一个节点往上遍历(寻找父亲节点)直到源节点,并标记遍历过的节点,再遍历另一个节点,若遇到的节点已经遍历过,说明该节点为最近公共祖先。
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
const int maxn=1e5+10;
int n,m,s;
int p[maxn],vis[maxn];
vector<int> g[maxn];
void dfs(int u)
{
int len=g[u].size();
for(int i=0;i<len;i++)
{
int v=g[u][i];
if(v==p[u])
continue;
p[v]=u;
dfs(v);
}
}
int lca(int x,int y)
{
memset(vis,0,sizeof(vis) );
while(x)
{
vis[x]=1;
x=p[x];
}
while(!vis[y])
y=p[y];
return y;
}
int main()
{
scanf("%d%d%d",&n,&m,&s);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
dfs(s);
while(m--)
{
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",lca(x,y));
}
return 0;
}
二、
深搜的时候记录每个节点的深度。
首先将两个节点置于同一深度,再将两个节点同时向上遍历,若遍历到同一节点,则说明该节点为最近公共祖先。
#include <iostream>
#include <vector>
#include <cstring>
#include <cstdio>
#include <cstdlib>
using namespace std;
const int maxn=1e5+10;
int n,m,s;
int p[maxn],dep[maxn];
vector<int> g[maxn];
void dfs(int u,int d)
{
dep[u]=d;
int len=g[u].size();
for(int i=0;i<len;i++)
{
int v=g[u][i];
if(v==p[u])
continue;
p[v]=u;
dfs(v,d+1);
}
}
int lca(int x,int y)
{
if(dep[x]<dep[y])
swap(x,y);
while(dep[x]>dep[y])
x=p[x];
while(x!=y)
{
x=p[x];
y=p[y];
}
return x;
}
int main()
{
scanf("%d%d%d",&n,&m,&s);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
dfs(s,0);
while(m--)
{
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",lca(x,y));
}
return 0;
}
三、
与二相同,只是往上遍历的时候,是倍增的遍历。
#include <iostream>
#include <vector>
#include <cstring>
#include <math.h>
#include <cstdio>
#include <cstdlib>
using namespace std;
const int maxn=1e6+10;
int n,m,s;
int f[maxn][25],dep[maxn];
vector<int> g[maxn];
void dfs(int u,int p,int d)
{
f[u][0]=p;
dep[u]=d;
int len=g[u].size();
for(int i=0;i<len;i++)
{
int v=g[u][i];
if(v==p)
continue;
dfs(v,u,d+1);
}
}
int lca(int x,int y)
{
if(dep[x]<dep[y])
swap(x,y);
for(int i=log2(dep[x]-dep[y]);i>=0;i--)
if((1<<i)<=dep[x]-dep[y])
x=f[x][i];
if(x==y)
return x;
for(int i=log2(dep[x]);i>=0;i--)
{
if(f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
int main()
{
scanf("%d%d%d",&n,&m,&s);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
g[x].push_back(y);
g[y].push_back(x);
}
dfs(s,s,0);
for(int j=1;j<25;j++)
for(int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1];
while(m--)
{
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",lca(x,y));
}
return 0;
}