LCA笔记

算法思想

最近公共祖先指有根树中距离两节点最近的公共祖先,祖先为当前节点早树根路径上的所有节点(包括自己),可以用LCA求解树上任意两点u,v的距离,公式为dist[u]+dist[v]-2×dist[lca]

暴力

1.向上标记法

从u向上遍历,一直到根节点,标记所有经过的节点,从v开始遍历,若v标记,v为LCA,当遇到第一次标记的节点,该节点为LCA

2.同步前进法

将u和v中的深度大的节点上升到与另一个相同的深度,然后一同前进,走到同一个节点时便得到LCA

树上倍增

树上倍增法为加上倍增思想的同步前进法,同样使用了ST表,F[i,j]表示i的2j辈祖先,F[i,0]为父节点,采用分治的思想,对于F[i,j],分成两个步骤,i节点先向根节点走2j-1步到F[i,j-1],之后再走2j-1步得到 F [ i , j ] = F [ F [ i , j − 1 ] , j − 1 ] F[i,j]=F[F[i,j-1],j-1] F[i,j]=F[F[i,j1],j1]

对于问题的求解,首先需要将u,v上升到同一深度,按照增量递减,若到达的节点深度比另一个深度小,不操作,到达的节点深度大于等于另一个深度,该点上移,到达同一深度后,同样按照增量递减的思路,到达节点相同,不操作,到达节点不同,上移,直到增量为0

实现代码如下

void ST()
{
	for(int j=1;j<=k;j++)//k为最大深度的log2值
		for(int i=1;i<=n;i++)
			F[i][j]=F[F[i][j-1]][j-1];
}
int LCA(int x,int y)
{
	if(d[x]>d[y])
		swap(x,y);
	for(int i=k;i>=0;i--)
		if(d[F[y][i]]>=d[x])
			y=F[y][i];
	if(x==y)
		return x;
	for(int i=k;i>=0;i--)
		if(F[y][i]!=F[x][i])
			y=F[y][i],x=F[x][i];
	return F[y][0];
}

在线RMQ

欧拉序列:DFS过程中依次将经过的节点记录,同样也记录回溯经过的节点,一个节点可能被记录多次,相当于从树根开始一笔画出一个经过所有节点的回路,其实就是前序搜索树

两个节点的LCA一定是两节点(首次出现下标间区间)间欧拉序列中深度最小节点,如下图,DFS序列为1 2 4 6 8 6 9 6 4 2 5 7 5 2 1 3 1,6和5首次出现下标为4和11,深度最小节点为2

在这里插入图片描述在线RMQ首先要构造出欧拉序列,代码如下

void dfs(int u,int d)
{
	vis[u]=1;
	pos[u]=++tot;//记录首次出现的下标,从0开始记录,对于每个大小只会记录一次
	seq[tot]=u;//记录欧拉序列
	dep[tot]=d;//记录该点深度
	for(int i=head[u];i;i=e[i].next)
	{
		int v=e[i].to,w=e[i].c;
		if(vis[v]) continue;
		dist[v]=dist[u]+w;
		dfs(v,d+1);
		seq[++tot]=u;//对回溯的节点的操作,回溯节点无pos记录
		dep[tot]=d;//同上
	}
}

创造完欧拉序列后,由于最后求解最小深度,因此对深度的数组创建区间最值查询ST,F(i,j)表示[i,i+2j-1]区间深度最小的节点下标

void ST()
{
	for(int i=1;i<=tot;i++)
		F[i][0]=i;//记录下标,不是深度最小值
	int k=log2(tot);
	for(int j=1;j<=k;j++)
		for(int i=1;i<=n-(1<<j)+1;i++)
			F[i][j]=dep[F[i][j-1]]<dep[F[i+(1<<(j-1))][j-1]]?F[i][j-1]:F[i+(1<<(j-1))][j-1];
}

查询代码如下

int RMQ(int l,int r)
{
	int k=log2(r-l+1);
	if(dep[F[l][k]]<dep[F[r-(1<<k)+1][k]])
		return F[l][k];
	else
		return F[r-(1<<k)+1][k];//注意,返回的是欧拉序列的下标
}
int LCA(int x,int y)
{
	int l=pos[x],r=pos[y];
	if(l>r) swap(l,r);
	return seq[RMQ(l,r)];//返回节点号
}

Tarjan

Tarjan是离线算法,读入所有查询,然后运行程序一次性得到所有查询答案,步骤如下

  1. 初始化集合号数组和访问数组,fa[i]=i,vis[i]=0
  2. 从节点u出发DFS,标记vis[u]=1,DFSu所有未访问的邻接点,遍历过程中更新距离,回退时更新集合号
  3. u的邻接点全部遍历完,检查u的所有查询,若存在一个查询u、v,vis[v]=1,利用并查集找到v的祖宗,找到的节点即u、v的最近公共祖先

当前节点u的邻接点已访问完毕时,检查u相关的所有查询v,若vis[v]≠1,不操作,否则用并查集查找v的祖宗,lca(u,v)=fa[v]。u的祖宗就是u向上查找第一个邻接点未访问完的节点,其fa[]未更新,仍满足fa[i]=i,它为v的祖宗

实现代码如下

int find(int x)
{
	if(x!=fa[x])
		return x;
	return fa[x]=find(fa[x]);
}
void tarjan(int u)
{
	vis[u]=1;
	for(int i=head[u];i;i=e[i].next)
	{
		int v=e[i].to,w=e[i].c;
		if(vis) continue;
		dis[v]=dis[u]+w;
		tarjan(v);
		fa[v]=u;
	}
	for(int i=0;i<query[u].size();i++)//u相关的所有查询,运行到这里的时候u的子树都已经查完了
	{
		int v=query[u][i];
		int id=query_id[u][i];//记录查询结果,获得查询编号
		if(vis[v])
		{
			int lca=find(v);
			ans[id]=dis[u]+dis[v]-2*dis[lca];//uv间距
		}
	}
}

训练

POJ1330

题目大意:给出一棵树,对于给定查询输出两个节点的LCA

思路:由于所给数据量小,因此可以用暴力(只有一个查询)

代码

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
using namespace std;
int T,N,fa[121212],u,v;
bool vis[121212];
int main() {
    scanf("%d",&T);
    while(T--) {
        scanf("%d",&N);
        for(int i=1; i<=N; i++)
            fa[i]=i;
        N--;//数据是N-1
        while(N--) {
            scanf("%d%d",&u,&v);
            fa[v]=u;
        }
        scanf("%d%d",&u,&v);
        vis[u]=1;
        while(fa[u]!=u) {
            u=fa[u];
            vis[u]=1;
        }
        if(vis[v]) {//如果已经遍历过了
            printf("%d\n",v);
            continue;
        }
        while(fa[v]!=v) {
            v=fa[v];
            if(vis[v]) {
                printf("%d\n",v);
                break;
            }
        }
        memset(vis,0,sizeof(vis));
    }
    return 0;
}

HDU2586

题目大意:n个点,点间有一些双向道路,没有两条路连接同一个节点,每条路有自己的长度,给出多个查询,询问从A点到B点需要走多远

思路:树上倍增+ST

代码

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
int T,head[400001],cnt,F[400001][30],dis[400001],d[400001],k,n,m;//注意空间大小
struct node {
    int to,l,next;
} edge[400001];
void Add(int from,int to,int w) {
    edge[++cnt].to=to;
    edge[cnt].l=w;
    edge[cnt].next=head[from];
    head[from]=cnt;
}//链式前向星存边
void DFS(int u) {
    for(int i=head[u]; i; i=edge[i].next) {
        int v=edge[i].to;
        if(F[u][0]==v)
            continue;
        d[v]=d[u]+1;
        k=max(d[v],k);
        dis[v]=dis[u]+edge[i].l;
        F[v][0]=u;
        DFS(v);
    }
}
void ST() {
    for(int j=1; j<=k; j++)
        for(int i=1; i<=n; i++)
            F[i][j]=F[F[i][j-1]][j-1];//构造ST
}
int LCA(int x,int y) {
    if(d[x]>d[y])
        swap(x,y);
    for(int i=k; i>=0; i--)//如果y的深度大
        if(d[F[y][i]]>=d[x])
            y=F[y][i];
    if(x==y)
        return x;
    for(int i=k; i>=0; i--)//此时x,y的深度相同
        if(F[x][i]!=F[y][i])
            x=F[x][i],y=F[y][i];
    return F[x][0];
}
int main() {
    scanf("%d",&T);
    while(T--) {
        scanf("%d%d",&n,&m);
        for(int i=1; i<n; i++) {
            int a,b,c;
            scanf("%d%d%d",&a,&b,&c);
            Add(a,b,c);
            Add(b,a,c);
        }
        DFS(1);
        k=log2(k);
        ST();
        while(m--) {
            int x,y;
            scanf("%d%d",&x,&y);
            printf("%d\n",dis[x]+dis[y]-2*dis[LCA(x,y)]);//求距离公式
        }
        memset(dis,0,sizeof(dis));
        memset(d,0,sizeof(d));
        memset(head,0,sizeof(head));
        for(int i=1; i<=cnt; i++)
            edge[i].l=edge[i].next=edge[i].to=0;
        cnt=0;
        k=0;
    }
    return 0;
}

POJ1986

题目大意:N个点,给出M条连接两个点之间的边,给出边的长度,给出K个查询,查询K对点之间的长度

思路:Tarjan算法

代码

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
int N,M,K,head[400001],cnt,fa[400001],dis[400001];
int qhead[400001];
bool vis[400001];
typedef struct node {
    int to,next,v;
} node;
node e[400001],qe[400001];
void Add(int to,int from,int v) {//链式前向星
    e[cnt].to=to;
    e[cnt].v=v;
    e[cnt].next=head[from];
    head[from]=cnt++;
}
void Addq(int to,int from,int v) {//离线处理数据,按顺序记录每条边
    qe[cnt].to=to;
    qe[cnt].v=v;
    qe[cnt].next=qhead[from];
    qhead[from]=cnt++;
}
int Find(int x) {
    if(x==fa[x])
        return x;
    return fa[x]=Find(fa[x]);
}
void LCA(int u) {
    fa[u]=u;
    vis[u]=1;
    for(int i=head[u]; i!=-1; i=e[i].next) {
        int v=e[i].to;
        if(!vis[v]) {
            dis[v]=dis[u]+e[i].v;
            LCA(v);
            fa[v]=u;
        }
    }
    for(int i=qhead[u]; i!=-1; i=qe[i].next) {
        int v=qe[i].to;
        if(vis[v]) {
            qe[i].v=dis[u]+dis[v]-2*dis[Find(v)];//获得LCA
            qe[i^1].v=qe[i].v;//i^1为i的相反边
        }
    }
}
int main() {
    scanf("%d%d",&N,&M);
    memset(head,-1,sizeof(head));
    memset(qhead,-1,sizeof(qhead));
    while(M--) {
        int a,b,l;
        scanf("%d%d%d",&a,&b,&l);
        Add(a,b,l);
        Add(b,a,l);
        getchar();//方向无用
        getchar();
    }
    scanf("%d",&K);
    cnt=0;
    while(K--) {
        int x,y;
        scanf("%d%d",&x,&y);
        Addq(x,y,0);
        Addq(y,x,0);
    }
    LCA(1);
    for(int i=0; i<cnt; i+=2)//相邻两条边为相互取反
        printf("%d\n",qe[i].v);
    return 0;
}

用在线RMQ代码如下

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
int N,M,K,head[400001],cnt,dep[400001],seq[400001],pos[400001],tot,F[400001][30],dis[400001];
bool vis[400001];
typedef struct node {
    int to,next,v;
} node;
node e[400001];
void Add(int to,int from,int v) {//链式前向星
    e[++cnt].to=to;
    e[cnt].v=v;
    e[cnt].next=head[from];
    head[from]=cnt;
}
void DFS(int u,int d) {//构造欧拉序列
    vis[u]=1;
    pos[u]=++tot;
    seq[tot]=u;
    dep[tot]=d;
    for(int i=head[u]; i; i=e[i].next) {
        int v=e[i].to,w=e[i].v;
        if(vis[v])
            continue;
        dis[v]=dis[u]+w;
        DFS(v,d+1);
        seq[++tot]=u;
        dep[tot]=d;
    }
}
void ST() {
    for(int i=1; i<=tot; i++)
        F[i][0]=i;
    int k=log2(tot);
    for(int j=1; j<=k; j++)
        for(int i=1; i<=tot-(1<<j)+1; i++)
            if(dep[F[i][j-1]]<dep[F[i+(1<<(j-1))][j-1]])
                F[i][j]=F[i][j-1];
            else
                F[i][j]=F[i+(1<<(j-1))][j-1];
}
int RMQ(int l,int r) {
    int k=log2(r-l+1);
    if(dep[F[l][k]]<dep[F[r-(1<<k)+1][k]])
        return F[l][k];
    else
        return F[r-(1<<k)+1][k];
}
int LCA(int x,int y) {
    int l=pos[x],r=pos[y];
    if(l>r)
        swap(l,r);
    return seq[RMQ(l,r)];
}
int main() {
    scanf("%d%d",&N,&M);
    while(M--) {
        int a,b,l;
        scanf("%d%d%d",&a,&b,&l);
        Add(a,b,l);
        Add(b,a,l);
        getchar();//方向无用
        getchar();
    }
    scanf("%d",&K);
    DFS(1,1);
    ST();
    while(K--) {
        int x,y;
        scanf("%d%d",&x,&y);
        printf("%d\n",dis[x]+dis[y]-2*dis[LCA(x,y)]);
    }
    return 0;
}

也可以用树上倍增法解决,代码如下

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
int N,M,head[400001],cnt,K,F[400001][30],k,dis[400001],d[400001];
struct node {
    int to,next,l;
} edge[400001];
void Add(int from,int to,int w) {
    edge[++cnt].to=to;
    edge[cnt].next=head[from];
    edge[cnt].l=w;
    head[from]=cnt;
}
void DFS(int u) {
    for(int i=head[u]; i; i=edge[i].next) {
        int v=edge[i].to;
        if(v==F[u][0])
            continue;
        d[v]=d[u]+1;
        k=max(k,d[v]);
        dis[v]=dis[u]+edge[i].l;
        F[v][0]=u;
        DFS(v);
    }
}
void ST() {
    for(int j=1; j<=k; j++)
        for(int i=1; i<=N; i++)
            F[i][j]=F[F[i][j-1]][j-1];
}
int LCA(int x,int y) {//获得公共祖先
    if(d[x]>d[y])
        swap(x,y);
    for(int i=k; i>=0; i--)
        if(d[F[y][i]]>=d[x])
            y=F[y][i];
    if(x==y)
        return x;
    for(int i=k; i>=0; i--)
        if(F[y][i]!=F[x][i])
            x=F[x][i],y=F[y][i];
    return F[x][0];
}
int main() {
    scanf("%d%d",&N,&M);
    while(M--) {
        int a,b,l;
        char d;
        scanf("%d%d%d",&a,&b,&l);
        cin >>d;
        Add(a,b,l);
        Add(b,a,l);
    }
    scanf("%d",&K);
    DFS(1);
    k=log2(k);
    ST();
    while(K--) {
        int a,b;
        scanf("%d%d",&a,&b);
        printf("%d\n",dis[a]+dis[b]-2*dis[LCA(a,b)]);
    }
    return 0;
}

HDU2874

题目大意:给出一个森林,判断一对点是否连通,如果连通输出两者间距离

思路:由于给的是森林,就需要对每个根节点进行tarjan

代码

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
const int maxm = 1e4+50;
const int maxq = 1e6+50;
int n,m,c,head[maxm],qhead[maxm],cnt,fa[maxm],dis[maxm],vis[maxm],ans[maxq];
struct node {
    int to,v,next;
} e[maxm*2];//注意内存大小,需要开两倍,在这里结构体有四个变量的话会内存超限
struct node1 {
    int to,id,next;
} qe[maxq*2];
void Add(int to,int from,int v) {
    e[cnt].to=to;
    e[cnt].v=v;
    e[cnt].next=head[from];
    head[from]=cnt++;
}
void Addq(int to,int from,int k) {
    qe[cnt].to=to;
    qe[cnt].next=qhead[from];
    qe[cnt].id=k;
    qhead[from]=cnt++;
}
int Find(int x) {
    if(x==fa[x])
        return x;
    return fa[x]=Find(fa[x]);
}
void LCA(int u,int deep,int root) {
    fa[u]=u;
    dis[u]=deep;
    vis[u]=root;
    for(int i=head[u]; ~i; i=e[i].next) {
        int v=e[i].to;
        if(vis[v]==-1) {
            LCA(v,deep+e[i].v,root);
            fa[v]=u;
        }
    }
    for(int i=qhead[u]; ~i; i=qe[i].next) {
        int v=qe[i].to;
        if(vis[v]==root)
            ans[qe[i].id]=dis[u]+dis[v]-2*dis[Find(v)];
    }
}
int main() {
    while(~scanf("%d%d%d",&n,&m,&c)) {
        memset(vis,-1,sizeof(vis));
        memset(fa,0,sizeof(fa));
        memset(head,-1,sizeof(head));
        memset(qhead,-1,sizeof(qhead));
        memset(dis,0,sizeof(dis));
        memset(ans,-1,sizeof(ans));
        cnt=0;
        while(m--) {
            int i,j,k;
            scanf("%d%d%d",&i,&j,&k);
            Add(i,j,k);
            Add(j,i,k);
        }
        cnt=0;
        for(int k=0; k<c; k++) {
            int i,j;
            scanf("%d%d",&i,&j);
            Addq(i,j,k);
            Addq(j,i,k);
        }
        for(int i=1; i<=n; i++)
            if(vis[i]==-1)
                LCA(i,0,i);
        for(int i=0; i<c; i++)
            if(ans[i]==-1)
                printf("Not connected\n");
            else
                printf("%d\n",ans[i]);
    }
    return 0;
}

对于节点的构造也可以这样写,相关代码修改为

struct node {
    int to,v,next;
} e[20002],qe[2000002];//关键之处,内存大小

void LCA(int u,int deep,int root) {
    fa[u]=u;
    dis[u]=deep;
    vis[u]=root;
    for(int i=head[u]; i!=-1; i=e[i].next) {
        int v=e[i].to;
        if(vis[v]==-1) {
            LCA(v,deep+e[i].v,root);
            fa[v]=u;
        }
    }
    for(int i=qhead[u]; i!=-1; i=qe[i].next) {
        int v=qe[i].to;
        if(vis[v]==root) {
            qe[i].v=dis[u]+dis[v]-2*dis[Find(v)];
            qe[i^1].v=qe[i].v;
        }
    }
}
	for(int i=0; i<cnt; i+=2)
            if(qe[i].v!=-1)
                printf("%d\n",qe[i].v);
            else
                printf("Not connected\n")

总结

LCA的方法多种多样,但最根本的还是DFS和ST表的应用,大多数方法也是基于这两个算法而演变而成的,掌握这两个算法很关键

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值