1.离线Tarjan
设我们求点对(u,v)的最近公共祖先。
利用在DFS过程中,从点u第一次到点v的过程中,必定是从u开始,经过u和v的最近公共祖先的S,然后到达v的,这时候u和v都是在S为根结点的子树里的。
如果我们在访问v的时候,u已经被访问过了,这时候,如果我们知道u在以哪个结点S为根结点的子树里,我们就知道它们的公共祖先了,即S。借助并查集,在DFS过程中,我们每到达一个节点x,便创建一棵以x为根结点的子树(即在并查集中令fa[x]=x),将不断将它的子结点连同子结点的子结点……合并到这棵子树下(即借助并查集的并操作令fa[son of x]=x),注意必须在子结点结束DFS过程后才合并,因为我们查询的点对(u,v)也有可能在以结点sone of x为根的子树下。则我们访问到u或者v的时候,判断v或者u是否已经被访问过,如果被访问过,那么v或者u所在的子树的根结点S即是它们的公共祖先,查找v或u所在子树的根节点借助并查集的查操作(即可即find_fa(v或者u))。
#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
const int N=10005;
int q1,q2,res,fa[N],deg[N];
int find_fa(int u)
{
if(fa[u]==u) return u;
return fa[u]=find_fa(fa[u]);
}
struct Edge
{
int to;
Edge *next;
}memo[N*2],*cur,*adj[N];
void addEdge(Edge *head[],int u,int v)
{
cur->to=v;
cur->next=head[u];
head[u]=cur++;
}
void tarjan(int u)
{
fa[u]=u;
for(Edge *it=adj[u];it;it=it->next)
{
int to=it->to;
tarjan(to);
fa[to]=u;
}
if(u==q1||u==q2)
{
if(u!=q1) swap(q1,q2);
if(fa[q2]) res=find_fa(fa[q2]);
}
}
void init(int n)
{
for(int i=0;i<=n;i++)
{
adj[i]=NULL;
deg[i]=0;fa[i]=0;
}
cur=memo;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
int n,u,v;
scanf("%d",&n);
init(n);
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&u,&v);
addEdge(adj,u,v); deg[v]++;
}
scanf("%d%d",&q1,&q2);
for(int i=1;i<=n;i++) if(deg[i]==0) {tarjan(i);break;}
printf("%d\n",res);
}
return 0;
}
2.倍增法
倍增法先利用一次dfs处理出每个结点的深度及其每个结点i的2^j层的祖先是谁——也可以表述成从结点i向上跳2*j层(即dp[i][j])。
对于询问u,v的公共祖先,假设u的深度比v小,先将u跳到与v同一层的某个u的祖先u1,这时候,如果u1与v是同一个点,那么它们的公共祖先就是这u1,即也是v。
否则,这时候,我们就要在树上向上跳若干层,使得u和v第一次交汇,也即得到u的某个祖先u2和v的某个祖先v2,满足u2=v2,那么那个交汇的点就是它们的最近公共祖先。如果我们跳到这个交汇点的下一层的结点,设为u3和v3,即满足dp[u3][0]=u2。注意到,从从u跳到u3的过程中,始终满足dp[u’][i]!=dp[v’][i]。
设u到u3的层数为x,那么这个x是可以表达成一个二进制数的,即x=2^k1+2^k2+..,即我们要从u达到u3要跳2^k1,2^k2…层——不就是我们处理出来的dp函数?那么,我们寻找满足dp[u’][k]!=dp[v’][k]的最大k值,然后跳到那一层(即令u’’=dp[u’][k],v’’=dp[v’][k]),然后再寻找满足dp[u’’][k]!=dp[v’’][k]的最大k值,再跳到那一层。。。这样到最后我们就到了u3层。如果还不理解可以画个图帮助理解下。
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int N=10005;
const int Log=20;
int dp[N][Log],depth[N],deg[N];
struct Edge
{
int to;
Edge *next;
}memo[N*2],*cur,*head[N];
void addEdge(int u,int v)
{
cur->to=v;
cur->next=head[u];
head[u]=cur++;
}
void dfs(int u)
{
depth[u]=depth[dp[u][0]]+1;
for(int i=1;i<Log;i++) dp[u][i]=dp[dp[u][i-1]][i-1];
for(Edge *it=head[u];it;it=it->next)
{
dfs(it->to);
}
}
int lca(int u,int v)
{
if(depth[u]<depth[v]) swap(u,v);
for(int st=1<<(Log-1),i=Log-1;i>=0;i--,st>>=1)
{
if(st<=depth[u]-depth[v])
{
u=dp[u][i];
}
}
if(u==v) return u;
for(int i=Log-1;i>=0;i--)
{
if(dp[v][i]!=dp[u][i])
{
v=dp[v][i];
u=dp[u][i];
}
}
return dp[u][0];
}
void init(int n)
{
for(int i=0;i<=n;i++)
{
dp[i][0]=0;
head[i]=NULL;
deg[i]=0;
}
cur=memo;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
int n,u,v;
scanf("%d",&n);
init(n);
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&u,&v);
addEdge(u,v);
deg[v]++;
dp[v][0]=u;
}
for(int i=1;i<=n;i++) if(deg[i]==0) {dfs(i);break;}
scanf("%d%d",&u,&v);
printf("%d\n",lca(u,v));
}
return 0;
}
3. 转化成RMQ
RMQ:区间最小值询问问题。
RMQ(A,i,j):对于线性序列A中,询问区间[i,j]上的最小值。
ST(Sparse Table)算法是一个非常有名的在线处理RMQ问题的算法,它可以在O(logN)时间内进行预处理,然后在O(1)的时候内回答每个查询。首先是预处理,用动态规划DP解决。设A[i]是要求区间最值的数列,dp[i,j]表示从第i个数起连续2^j个数中的最大值。例如数列3 2 4 5 6 8 1 2 9 7,dp[1,0]表示从第1个数起,长度为2^0=1的最大值,其实就是3这个数。dp[1,2]=5,dp[1,3]=8,dp[2,0]=2,dp[2,1]=4……从这里可以看出dp[i,0]其实就等于A[i]。这样,DP的状态,初值都已经有了,剩下的就是状态转移方程。我们把dp[i,j]平均分成两段(因为dp[i,j]一定是偶数个数字),从i+2^(j-1)-1为一段,从i+2^(j-1)到i+2^j-1为一段(长度都为2^(j-1))。用上例说明,当i=1,j=3时就是3 2 4 5 和 6 8 1 2 这两段。F[i,j]就是这两段的最大值中的最大值。于是我们得到了动态规划方程dp[i,j]=max(dp[i,j-1],dp[i+2^(j-1),2^(j-1)。
然后是查询。取k=[log2(j-i+1)],则有:RMQ(A,i,j)=min(dp[i,k],dp[j-2^k+1,k])。举例说明,要求区间[2,8]的最大值,就要把它分成[2,5]和[5,8]两个区间,因为这两个区间的最大值我们可以直接由dp[2,2]和dp[5,2]得到。
对有根树T进行DFS,将遍历到的结点按照顺序记下,我们将得到一个长度为2N – 1的序列,称之为T的欧拉序列F。
每个结点都在欧拉序列中出现,我们记录结点u在欧拉序列中第一次出现的位置为pos(u)。
下图是一个例子:
根据DFS的性质,对于两结点u、v,从pos(u)遍历到pos(v)的过程中经过LCA(u,v)有且仅有一次,且深度是深度序列B[pos(u)…pos(v)]中最小的。
即LCA(T, u, v) =RMQ(B, pos(u), pos(v))
有两种写法。第一 种:在rmq_depth里记录在如上图所示的深度序列,在rmq_hash里记录每个深度序列里每个位置对应的点,dp[i][j]表示从i开始的长度为2^j的序列中深度最小的值的位置。比如查询(u,v)的最近公共祖先,即是rmq_hash[min(dp[u][k],dp[v-(1<<k)+1][k])];
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;
const int N=10005;
const int Log=30;
int deg[N],dp[N*2][Log];
int dfn;//dfs过程中用到的时间戳
int rmq_pos[N];//表示节点u第一次出现的位置
int rmq_depth[N*2];//表示路径上的每个点的深度
int rmq_hash[N*2];//表示路径上的每个深度代表的点
struct Edge
{
int to;
Edge *next;
}memo[N*2],*cur,*adj[N];
void addEdge(Edge *head[],int u,int v)
{
cur->to=v;
cur->next=head[u];
head[u]=cur++;
}
void dfs(int u,int d)
{
rmq_pos[u]=dfn;
rmq_depth[dfn]=d;
rmq_hash[dfn++]=u;
for(Edge *it=adj[u];it;it=it->next)
{
int v=it->to;
dfs(v,d+1);
rmq_depth[dfn]=d;
rmq_hash[dfn++]=u;
}
}
void solve(int n)
{
for(int i=1;i<=n;i++) dp[i][0]=i;
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i+(1<<j)-1<=n;i++)
{
int tmp1=i,tmp2=i+(1<<(j-1));
if(rmq_depth[dp[tmp1][j-1]]<rmq_depth[dp[tmp2][j-1]])
dp[i][j]=dp[tmp1][j-1];
else dp[i][j]=dp[tmp2][j-1];
}
}
int rmq(int u,int v)
{
u=rmq_pos[u],v=rmq_pos[v];
if(u>v) swap(u,v);
int k=(int)(log(v*1.0-u+1)/log(2.0));
int tmp1=u,tmp2=v-(1<<k)+1;
return rmq_hash[min(dp[tmp1][k],dp[tmp2][k])];
}
void init(int n)
{
dfn=1;
for(int i=0;i<=n;i++)
{
adj[i]=NULL;
deg[i]=0;
}
cur=memo;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
int n,u,v;
scanf("%d",&n);
init(n);
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&u,&v);
addEdge(adj,u,v); deg[v]++;
}
for(int i=1;i<=n;i++) if(deg[i]==0){dfs(i,0);break;}
solve(2*n-1);
scanf("%d%d",&u,&v);
printf("%d\n",rmq(u,v));
}
return 0;
}
第二种:对第一种写法做了些改进,即如果在一棵子树下,深度值较大的同时编号也较大,那么在对dp数组进行处理时就可以不用借助上面的rmq_depth了。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;
const int N=10005;
const int Log=30;
int deg[N],dp[N*2][Log];
int dfn;
int rmq_low[N*2];
int rmq_pos[N*2];
int rmq_hash[N*2];
struct Edge
{
int to;
Edge *next;
}memo[N*2],*cur,*adj[N];
void addEdge(Edge *head[],int u,int v)
{
cur->to=v;
cur->next=head[u];
head[u]=cur++;
}
void dfs(int u)
{
int tmp=dfn;
rmq_low[dfn]=dfn;
rmq_pos[u]=dfn;
rmq_hash[dfn++]=u;
for(Edge *it=adj[u];it;it=it->next)
{
int v=it->to;
dfs(v);
rmq_low[dfn]=tmp;
rmq_hash[dfn++]=u;
}
}
void solve(int n)
{
for(int i=1;i<=n;i++) dp[i][0]=rmq_low[i];
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i+(1<<j)-1<=n;i++)
dp[i][j]=min(dp[i][j-1],dp[i+(1<<(j-1))][j-1]);
}
int rmq(int u,int v)
{
u=rmq_pos[u],v=rmq_pos[v];
if(u>v) swap(u,v);
int k=(int)(log(v*1.0-u+1)/log(2.0));
int tmp1=u,tmp2=v-(1<<k)+1;
return rmq_hash[min(dp[tmp1][k],dp[tmp2][k])];
}
void init(int n)
{
dfn=1;
for(int i=0;i<=n;i++)
{
deg[i]=0;
adj[i]=NULL;
}
cur=memo;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
int n,u,v;
scanf("%d",&n);
init(n);
for(int i=0;i<n-1;i++)
{
scanf("%d%d",&u,&v);
addEdge(adj,u,v); deg[v]++;
}
for(int i=1;i<=n;i++) if(deg[i]==0){dfs(i);break;}
solve(2*n-1);
scanf("%d%d",&u,&v);
printf("%d\n",rmq(u,v));
}
return 0;
}