算法思想
最近公共祖先指有根树中距离两节点最近的公共祖先,祖先为当前节点早树根路径上的所有节点(包括自己),可以用LCA求解树上任意两点u,v的距离,公式为dist[u]+dist[v]-2×dist[lca]
暴力
1.向上标记法
从u向上遍历,一直到根节点,标记所有经过的节点,从v开始遍历,若v标记,v为LCA,当遇到第一次标记的节点,该节点为LCA
2.同步前进法
将u和v中的深度大的节点上升到与另一个相同的深度,然后一同前进,走到同一个节点时便得到LCA
树上倍增
树上倍增法为加上倍增思想的同步前进法,同样使用了ST表,F[i,j]表示i的2j辈祖先,F[i,0]为父节点,采用分治的思想,对于F[i,j],分成两个步骤,i节点先向根节点走2j-1步到F[i,j-1],之后再走2j-1步得到 F [ i , j ] = F [ F [ i , j − 1 ] , j − 1 ] F[i,j]=F[F[i,j-1],j-1] F[i,j]=F[F[i,j−1],j−1]
对于问题的求解,首先需要将u,v上升到同一深度,按照增量递减,若到达的节点深度比另一个深度小,不操作,到达的节点深度大于等于另一个深度,该点上移,到达同一深度后,同样按照增量递减的思路,到达节点相同,不操作,到达节点不同,上移,直到增量为0
实现代码如下
void ST()
{
for(int j=1;j<=k;j++)//k为最大深度的log2值
for(int i=1;i<=n;i++)
F[i][j]=F[F[i][j-1]][j-1];
}
int LCA(int x,int y)
{
if(d[x]>d[y])
swap(x,y);
for(int i=k;i>=0;i--)
if(d[F[y][i]]>=d[x])
y=F[y][i];
if(x==y)
return x;
for(int i=k;i>=0;i--)
if(F[y][i]!=F[x][i])
y=F[y][i],x=F[x][i];
return F[y][0];
}
在线RMQ
欧拉序列:DFS过程中依次将经过的节点记录,同样也记录回溯经过的节点,一个节点可能被记录多次,相当于从树根开始一笔画出一个经过所有节点的回路,其实就是前序搜索树
两个节点的LCA一定是两节点(首次出现下标间区间)间欧拉序列中深度最小节点,如下图,DFS序列为1 2 4 6 8 6 9 6 4 2 5 7 5 2 1 3 1,6和5首次出现下标为4和11,深度最小节点为2
在线RMQ首先要构造出欧拉序列,代码如下
void dfs(int u,int d)
{
vis[u]=1;
pos[u]=++tot;//记录首次出现的下标,从0开始记录,对于每个大小只会记录一次
seq[tot]=u;//记录欧拉序列
dep[tot]=d;//记录该点深度
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to,w=e[i].c;
if(vis[v]) continue;
dist[v]=dist[u]+w;
dfs(v,d+1);
seq[++tot]=u;//对回溯的节点的操作,回溯节点无pos记录
dep[tot]=d;//同上
}
}
创造完欧拉序列后,由于最后求解最小深度,因此对深度的数组创建区间最值查询ST,F(i,j)表示[i,i+2j-1]区间深度最小的节点下标
void ST()
{
for(int i=1;i<=tot;i++)
F[i][0]=i;//记录下标,不是深度最小值
int k=log2(tot);
for(int j=1;j<=k;j++)
for(int i=1;i<=n-(1<<j)+1;i++)
F[i][j]=dep[F[i][j-1]]<dep[F[i+(1<<(j-1))][j-1]]?F[i][j-1]:F[i+(1<<(j-1))][j-1];
}
查询代码如下
int RMQ(int l,int r)
{
int k=log2(r-l+1);
if(dep[F[l][k]]<dep[F[r-(1<<k)+1][k]])
return F[l][k];
else
return F[r-(1<<k)+1][k];//注意,返回的是欧拉序列的下标
}
int LCA(int x,int y)
{
int l=pos[x],r=pos[y];
if(l>r) swap(l,r);
return seq[RMQ(l,r)];//返回节点号
}
Tarjan
Tarjan是离线算法,读入所有查询,然后运行程序一次性得到所有查询答案,步骤如下
- 初始化集合号数组和访问数组,fa[i]=i,vis[i]=0
- 从节点u出发DFS,标记vis[u]=1,DFSu所有未访问的邻接点,遍历过程中更新距离,回退时更新集合号
- u的邻接点全部遍历完,检查u的所有查询,若存在一个查询u、v,vis[v]=1,利用并查集找到v的祖宗,找到的节点即u、v的最近公共祖先
当前节点u的邻接点已访问完毕时,检查u相关的所有查询v,若vis[v]≠1,不操作,否则用并查集查找v的祖宗,lca(u,v)=fa[v]。u的祖宗就是u向上查找第一个邻接点未访问完的节点,其fa[]未更新,仍满足fa[i]=i,它为v的祖宗
实现代码如下
int find(int x)
{
if(x!=fa[x])
return x;
return fa[x]=find(fa[x]);
}
void tarjan(int u)
{
vis[u]=1;
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to,w=e[i].c;
if(vis) continue;
dis[v]=dis[u]+w;
tarjan(v);
fa[v]=u;
}
for(int i=0;i<query[u].size();i++)//u相关的所有查询,运行到这里的时候u的子树都已经查完了
{
int v=query[u][i];
int id=query_id[u][i];//记录查询结果,获得查询编号
if(vis[v])
{
int lca=find(v);
ans[id]=dis[u]+dis[v]-2*dis[lca];//uv间距
}
}
}
训练
POJ1330
题目大意:给出一棵树,对于给定查询输出两个节点的LCA
思路:由于所给数据量小,因此可以用暴力(只有一个查询)
代码
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
using namespace std;
int T,N,fa[121212],u,v;
bool vis[121212];
int main() {
scanf("%d",&T);
while(T--) {
scanf("%d",&N);
for(int i=1; i<=N; i++)
fa[i]=i;
N--;//数据是N-1
while(N--) {
scanf("%d%d",&u,&v);
fa[v]=u;
}
scanf("%d%d",&u,&v);
vis[u]=1;
while(fa[u]!=u) {
u=fa[u];
vis[u]=1;
}
if(vis[v]) {//如果已经遍历过了
printf("%d\n",v);
continue;
}
while(fa[v]!=v) {
v=fa[v];
if(vis[v]) {
printf("%d\n",v);
break;
}
}
memset(vis,0,sizeof(vis));
}
return 0;
}
HDU2586
题目大意:n个点,点间有一些双向道路,没有两条路连接同一个节点,每条路有自己的长度,给出多个查询,询问从A点到B点需要走多远
思路:树上倍增+ST
代码
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
int T,head[400001],cnt,F[400001][30],dis[400001],d[400001],k,n,m;//注意空间大小
struct node {
int to,l,next;
} edge[400001];
void Add(int from,int to,int w) {
edge[++cnt].to=to;
edge[cnt].l=w;
edge[cnt].next=head[from];
head[from]=cnt;
}//链式前向星存边
void DFS(int u) {
for(int i=head[u]; i; i=edge[i].next) {
int v=edge[i].to;
if(F[u][0]==v)
continue;
d[v]=d[u]+1;
k=max(d[v],k);
dis[v]=dis[u]+edge[i].l;
F[v][0]=u;
DFS(v);
}
}
void ST() {
for(int j=1; j<=k; j++)
for(int i=1; i<=n; i++)
F[i][j]=F[F[i][j-1]][j-1];//构造ST
}
int LCA(int x,int y) {
if(d[x]>d[y])
swap(x,y);
for(int i=k; i>=0; i--)//如果y的深度大
if(d[F[y][i]]>=d[x])
y=F[y][i];
if(x==y)
return x;
for(int i=k; i>=0; i--)//此时x,y的深度相同
if(F[x][i]!=F[y][i])
x=F[x][i],y=F[y][i];
return F[x][0];
}
int main() {
scanf("%d",&T);
while(T--) {
scanf("%d%d",&n,&m);
for(int i=1; i<n; i++) {
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
Add(a,b,c);
Add(b,a,c);
}
DFS(1);
k=log2(k);
ST();
while(m--) {
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",dis[x]+dis[y]-2*dis[LCA(x,y)]);//求距离公式
}
memset(dis,0,sizeof(dis));
memset(d,0,sizeof(d));
memset(head,0,sizeof(head));
for(int i=1; i<=cnt; i++)
edge[i].l=edge[i].next=edge[i].to=0;
cnt=0;
k=0;
}
return 0;
}
POJ1986
题目大意:N个点,给出M条连接两个点之间的边,给出边的长度,给出K个查询,查询K对点之间的长度
思路:Tarjan算法
代码
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
int N,M,K,head[400001],cnt,fa[400001],dis[400001];
int qhead[400001];
bool vis[400001];
typedef struct node {
int to,next,v;
} node;
node e[400001],qe[400001];
void Add(int to,int from,int v) {//链式前向星
e[cnt].to=to;
e[cnt].v=v;
e[cnt].next=head[from];
head[from]=cnt++;
}
void Addq(int to,int from,int v) {//离线处理数据,按顺序记录每条边
qe[cnt].to=to;
qe[cnt].v=v;
qe[cnt].next=qhead[from];
qhead[from]=cnt++;
}
int Find(int x) {
if(x==fa[x])
return x;
return fa[x]=Find(fa[x]);
}
void LCA(int u) {
fa[u]=u;
vis[u]=1;
for(int i=head[u]; i!=-1; i=e[i].next) {
int v=e[i].to;
if(!vis[v]) {
dis[v]=dis[u]+e[i].v;
LCA(v);
fa[v]=u;
}
}
for(int i=qhead[u]; i!=-1; i=qe[i].next) {
int v=qe[i].to;
if(vis[v]) {
qe[i].v=dis[u]+dis[v]-2*dis[Find(v)];//获得LCA
qe[i^1].v=qe[i].v;//i^1为i的相反边
}
}
}
int main() {
scanf("%d%d",&N,&M);
memset(head,-1,sizeof(head));
memset(qhead,-1,sizeof(qhead));
while(M--) {
int a,b,l;
scanf("%d%d%d",&a,&b,&l);
Add(a,b,l);
Add(b,a,l);
getchar();//方向无用
getchar();
}
scanf("%d",&K);
cnt=0;
while(K--) {
int x,y;
scanf("%d%d",&x,&y);
Addq(x,y,0);
Addq(y,x,0);
}
LCA(1);
for(int i=0; i<cnt; i+=2)//相邻两条边为相互取反
printf("%d\n",qe[i].v);
return 0;
}
用在线RMQ代码如下
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
int N,M,K,head[400001],cnt,dep[400001],seq[400001],pos[400001],tot,F[400001][30],dis[400001];
bool vis[400001];
typedef struct node {
int to,next,v;
} node;
node e[400001];
void Add(int to,int from,int v) {//链式前向星
e[++cnt].to=to;
e[cnt].v=v;
e[cnt].next=head[from];
head[from]=cnt;
}
void DFS(int u,int d) {//构造欧拉序列
vis[u]=1;
pos[u]=++tot;
seq[tot]=u;
dep[tot]=d;
for(int i=head[u]; i; i=e[i].next) {
int v=e[i].to,w=e[i].v;
if(vis[v])
continue;
dis[v]=dis[u]+w;
DFS(v,d+1);
seq[++tot]=u;
dep[tot]=d;
}
}
void ST() {
for(int i=1; i<=tot; i++)
F[i][0]=i;
int k=log2(tot);
for(int j=1; j<=k; j++)
for(int i=1; i<=tot-(1<<j)+1; i++)
if(dep[F[i][j-1]]<dep[F[i+(1<<(j-1))][j-1]])
F[i][j]=F[i][j-1];
else
F[i][j]=F[i+(1<<(j-1))][j-1];
}
int RMQ(int l,int r) {
int k=log2(r-l+1);
if(dep[F[l][k]]<dep[F[r-(1<<k)+1][k]])
return F[l][k];
else
return F[r-(1<<k)+1][k];
}
int LCA(int x,int y) {
int l=pos[x],r=pos[y];
if(l>r)
swap(l,r);
return seq[RMQ(l,r)];
}
int main() {
scanf("%d%d",&N,&M);
while(M--) {
int a,b,l;
scanf("%d%d%d",&a,&b,&l);
Add(a,b,l);
Add(b,a,l);
getchar();//方向无用
getchar();
}
scanf("%d",&K);
DFS(1,1);
ST();
while(K--) {
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",dis[x]+dis[y]-2*dis[LCA(x,y)]);
}
return 0;
}
也可以用树上倍增法解决,代码如下
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
int N,M,head[400001],cnt,K,F[400001][30],k,dis[400001],d[400001];
struct node {
int to,next,l;
} edge[400001];
void Add(int from,int to,int w) {
edge[++cnt].to=to;
edge[cnt].next=head[from];
edge[cnt].l=w;
head[from]=cnt;
}
void DFS(int u) {
for(int i=head[u]; i; i=edge[i].next) {
int v=edge[i].to;
if(v==F[u][0])
continue;
d[v]=d[u]+1;
k=max(k,d[v]);
dis[v]=dis[u]+edge[i].l;
F[v][0]=u;
DFS(v);
}
}
void ST() {
for(int j=1; j<=k; j++)
for(int i=1; i<=N; i++)
F[i][j]=F[F[i][j-1]][j-1];
}
int LCA(int x,int y) {//获得公共祖先
if(d[x]>d[y])
swap(x,y);
for(int i=k; i>=0; i--)
if(d[F[y][i]]>=d[x])
y=F[y][i];
if(x==y)
return x;
for(int i=k; i>=0; i--)
if(F[y][i]!=F[x][i])
x=F[x][i],y=F[y][i];
return F[x][0];
}
int main() {
scanf("%d%d",&N,&M);
while(M--) {
int a,b,l;
char d;
scanf("%d%d%d",&a,&b,&l);
cin >>d;
Add(a,b,l);
Add(b,a,l);
}
scanf("%d",&K);
DFS(1);
k=log2(k);
ST();
while(K--) {
int a,b;
scanf("%d%d",&a,&b);
printf("%d\n",dis[a]+dis[b]-2*dis[LCA(a,b)]);
}
return 0;
}
HDU2874
题目大意:给出一个森林,判断一对点是否连通,如果连通输出两者间距离
思路:由于给的是森林,就需要对每个根节点进行tarjan
代码
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
const int maxm = 1e4+50;
const int maxq = 1e6+50;
int n,m,c,head[maxm],qhead[maxm],cnt,fa[maxm],dis[maxm],vis[maxm],ans[maxq];
struct node {
int to,v,next;
} e[maxm*2];//注意内存大小,需要开两倍,在这里结构体有四个变量的话会内存超限
struct node1 {
int to,id,next;
} qe[maxq*2];
void Add(int to,int from,int v) {
e[cnt].to=to;
e[cnt].v=v;
e[cnt].next=head[from];
head[from]=cnt++;
}
void Addq(int to,int from,int k) {
qe[cnt].to=to;
qe[cnt].next=qhead[from];
qe[cnt].id=k;
qhead[from]=cnt++;
}
int Find(int x) {
if(x==fa[x])
return x;
return fa[x]=Find(fa[x]);
}
void LCA(int u,int deep,int root) {
fa[u]=u;
dis[u]=deep;
vis[u]=root;
for(int i=head[u]; ~i; i=e[i].next) {
int v=e[i].to;
if(vis[v]==-1) {
LCA(v,deep+e[i].v,root);
fa[v]=u;
}
}
for(int i=qhead[u]; ~i; i=qe[i].next) {
int v=qe[i].to;
if(vis[v]==root)
ans[qe[i].id]=dis[u]+dis[v]-2*dis[Find(v)];
}
}
int main() {
while(~scanf("%d%d%d",&n,&m,&c)) {
memset(vis,-1,sizeof(vis));
memset(fa,0,sizeof(fa));
memset(head,-1,sizeof(head));
memset(qhead,-1,sizeof(qhead));
memset(dis,0,sizeof(dis));
memset(ans,-1,sizeof(ans));
cnt=0;
while(m--) {
int i,j,k;
scanf("%d%d%d",&i,&j,&k);
Add(i,j,k);
Add(j,i,k);
}
cnt=0;
for(int k=0; k<c; k++) {
int i,j;
scanf("%d%d",&i,&j);
Addq(i,j,k);
Addq(j,i,k);
}
for(int i=1; i<=n; i++)
if(vis[i]==-1)
LCA(i,0,i);
for(int i=0; i<c; i++)
if(ans[i]==-1)
printf("Not connected\n");
else
printf("%d\n",ans[i]);
}
return 0;
}
对于节点的构造也可以这样写,相关代码修改为
struct node {
int to,v,next;
} e[20002],qe[2000002];//关键之处,内存大小
void LCA(int u,int deep,int root) {
fa[u]=u;
dis[u]=deep;
vis[u]=root;
for(int i=head[u]; i!=-1; i=e[i].next) {
int v=e[i].to;
if(vis[v]==-1) {
LCA(v,deep+e[i].v,root);
fa[v]=u;
}
}
for(int i=qhead[u]; i!=-1; i=qe[i].next) {
int v=qe[i].to;
if(vis[v]==root) {
qe[i].v=dis[u]+dis[v]-2*dis[Find(v)];
qe[i^1].v=qe[i].v;
}
}
}
for(int i=0; i<cnt; i+=2)
if(qe[i].v!=-1)
printf("%d\n",qe[i].v);
else
printf("Not connected\n")
总结
LCA的方法多种多样,但最根本的还是DFS和ST表的应用,大多数方法也是基于这两个算法而演变而成的,掌握这两个算法很关键