目录
P3379 【模板】最近公共祖先(LCA)
暴力
操作步骤:
- 求出每个结点的深度;
- 询问两个结点是否重合,若重合,则LCA已经求出;
- 否则,选择两个点中深度较大的一个,并移动到它的父亲。
int LCA(int x,int y)
{
while(x!=y)
{
if(depth[x]>=depth[y]) x=fa[x];
else y=fa[y];
}
return x;
}
倍增法
操作步骤:
- 求出倍增数组;
- 把两个点移动到同一深度;
- 逐步试探出LCA。
#include<bits/stdc++.h>
using namespace std;
struct Edge
{
int to,next;
}edge[500005*2];//无向图,两倍开
int head[500005],grand[500005][21],depth[500005],lg[500001];
int cnt,n,m,s;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
return x*f;
}
void add(int x,int y)
{
edge[++cnt].to=y;
edge[cnt].next=head[x];
head[x]=cnt;
}
void dfs(int now,int fa)
{
depth[now]=depth[fa]+1;
grand[now][0]=fa;
for(int i=1;i<=lg[depth[now]];i++)
//for(int i=1;(1<<i)<=depth[now];i++)
grand[now][i]=grand[grand[now][i-1]][i-1];
//爸爸的爸爸叫爷爷~~~
for(int i=head[now];i;i=edge[i].next)
//遍历和当前结点相连的所有的边(按输入的倒序),最后一条边的 edge[i].next==0
{
cout<<"第"<<i<<"条边,指向" <<edge[i].to<<endl;
if(edge[i].to!=fa)
dfs(edge[i].to,now);
}
}
int LCA(int a,int b)
{
if(depth[a]<depth[b])
swap(a,b);
while(depth[a]>depth[b])
a=grand[a][lg[depth[a]-depth[b]]-1];
//倍增法逼近,e.g:depth[a]-depth[b]==14
//lg[depth[a]-depth[b]]-1==3,a上升8个深度,depth[a]-depth[b]==6;
//lg[depth[a]-depth[b]]-1==2,a上升4个深度,depth[a]-depth[b]==2;
//lg[depth[a]-depth[b]]-1==1,a上升2个深度,depth[a]-depth[b]==0;
if(a==b) return a;//a和b的LCA就是a
for(int k=lg[depth[a]]-1;k>=0;k--)
if(grand[a][k]!=grand[b][k])
a=grand[a][k],b=grand[b][k];
//从远古祖先(注意不要越界)中逐渐向最近的试探
// e.g:depth[a]==14,depth[LCA]==7;
// k=lg[depth[a]]-1,k==3;grand[a][k]==grand[b][k];continue;
//k==2,grand[a][k]!=grand[b][k],a,b一起向上4个深度;
//k==1,grand[a][k]!=grand[b][k],a,b一起向上2个深度;
//k==0,grand[a][k]!=grand[b][k],a,b一起向上1个深度;
//一共向上4+2+1==7个深度,找到LCA
return grand[a][0];
}
int main()
{
n=read(),m=read(),s=read();
for(int i=1;i<n;i++)
{
int a,b;
a=read(),b=read();
add(a,b);
add(b,a);
}
for(int i=1;i<=n;i++)
lg[i]=lg[i-1]+((1<<lg[i-1])==i);//log_{2}{i}+1
dfs(s,0);//从根结点开始搜索
while(m--)
{
int x,y;
x=read(),y=read();
printf("%d\n",LCA(x,y));
}
return 0;
}
RMQ+ST
(转化为欧拉序列上的RMQ问题,采用ST算法)
名词解释:
- 欧拉序列:每经过一个结点,都进行一次统计产生的DFS序列;
- RMQ:指的一类连续查询区间最小(最大)值的问题;
- ST算法:求解RMQ问题的算法。
不了解ST算法点这里:
P3865 【模板】ST 表https://www.luogu.com.cn/problem/P3865
ST表:
#include<bits/stdc++.h>
using namespace std;
int n,m,f[100010][20];
inline int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
return x*f;
}
int main()
{
cin>>n>>m;
for(int i=1;i<=n;i++)
f[i][0]=read();
int t=log(n)/log(2);
for(int j=1;j<=t;j++)
for(int i=1;i<=n-(1<<j)+1;i++)
f[i][j]=max(f[i][j-1],f[i+(1<<(j-1))][j-1]);
while(m--)
{
int r,l;
l=read(),r=read();
int k=log(r-l+1)/log(2);
int ans=max(f[l][k],f[r-(1<<k)+1][k]);
printf("%d\n",ans);
}
return 0;
}
操作步骤:
- DFS求出欧拉序列和深度序列,以及每个结点在欧拉序列中第一次出现的位置;
- 找到查询的两个结点在欧拉序列中第一次出现的位置;
- 在深度序列中两个位置之间的区间找到深度最小的点。
P.S.:假如两个结点在欧拉序列中不止出现一次,只需要任选其中一次来计算即可。
//3.95s / 227.49MB / 1.67KB C++14 (GCC 9) O2
#include<bits/stdc++.h>
using namespace std;
const int N=500005;
vector<int>vec[N];
//记录每个结点可以走向哪些结点
int f[N*2][21],mem[N*2][21],depth[N*2],first[N],vis[N*2],lg[N];
//f:记录深度序列区间中的最小深度
//mem:记录 找到深度序列区间中的最小深度 时的对应结点(在欧拉序列中)
//depth:在dfs过程中记录遍历到每个点时的对应深度
//first:记录每个结点第一次出现时在欧拉序列中的位置
//vis:欧拉序列
//lg;lg[i]==log_{2}{i}+1
int cnt=0,n,m,s;
//cnt:每走到一个点计一次数
inline int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
return x*f;
}
void dfs(int now,int dep)
{
if(!first[now]) first[now]=++cnt;//第一次遍历到该点
depth[cnt]=dep,vis[cnt]=now;
for(int i=0;i<vec[now].size();i++)
{
if(first[vec[now][i]]) continue;//是该结点的父节点,跳过
else dfs(vec[now][i],dep+1);
++cnt;
depth[cnt]=dep,vis[cnt]=now;//深搜完了vec[now][i]下的分支,回到当前结点 now
}
}
void RMQ()
{
for(int i=1;i<=cnt;i++)
{
lg[i]=lg[i-1]+((1<<lg[i-1])==i);
f[i][0]=depth[i];//区间长度为1时,该区间内深度的最小值就是该结点的深度
mem[i][0]=vis[i];
}
for(int j=1;(1<<j)<=cnt;j++)//枚举的区间长度倍增
for(int i=1;i+(1<<j)-1<=cnt;i++)//枚举合法的每个区间起点
{
if(f[i][j-1]<f[i+(1<<(j-1))][j-1])//深度最小的点在前半个区间
{
f[i][j]=f[i][j-1];
mem[i][j]=mem[i][j-1];
}
else//深度最小的后半个区间
{
f[i][j]=f[i+(1<<(j-1))][j-1];
mem[i][j]=mem[i+(1<<(j-1))][j-1];
}
}
}
int ST(int x,int y)
{
int l=first[x],r=first[y];//找到输入的两个结点编号对应在欧拉序列中第一次出现的位置
if(l>r) swap(l,r);
int k=lg[r-l+1]-1;
if(f[l][k]<f[r-(1<<k)+1][k]) return mem[l][k];
else return mem[r-(1<<k)+1][k];
}
int main()
{
n=read(),m=read(),s=read();
for(int i=1;i<n;i++)
{
int a,b;
a=read(),b=read();
vec[a].push_back(b);
vec[b].push_back(a);
}
dfs(s,0);//打表,给first、depth、vis赋值,给RMQ奠定基础
/*cout<<"各结点第一次出现的位置:"<<endl;
for(int i=1;i<=n;i++)
cout<<first[i]<<' ';
cout<<endl;
cout<<"欧拉序列:"<<endl;
for(int i=1;i<=2*n;i++)
cout<<vis[i]<<' ';
cout<<endl;
cout<<"深度序列"<<endl;
for(int i=1;i<=2*n;i++)
cout<<depth[i]<<' ';
cout<<endl;*/
RMQ();//打表,给f和mem赋值 ,给ST奠定基础
while(m--)
{
int x,y;
x=read(),y=read();
printf("%d\n",ST(x,y));
}
return 0;
}
算法效率:
预处理时间复杂度:
单次询问时间复杂度:
总时间复杂度:
空间复杂度:
Tarjan
(这还没学,随便写写)
操作步骤:
- DFS整棵树。每个结点x一开始属于只有该结点本身的集合
;
- DFS(x)时,每次访问子树y时,把
合并到
;
- x的所有子结点访问完,标记x为已访问;
- 遍历所有关于x的询问(x,y), 如果y已被访问,则这个询问的答案为并查集中的Find(y)。
void dfs(int x)
{
for(int i=0;i<g[x].size();i++)
{
dfs(g[x][i]);
uni(g[x][i],x);
}
vis[x]=1;
for(int i=0;i<query[x].size();i++)
{
int y=query[x][i];
if(vis[y]) ans[x][y]=find(y);
}
}
四个方法的优缺点比较:
(n个点,q次询问)
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
int nex[N],head[N],ver[N],val[N],n,m,x,y,tot,f[N][31],depth[N],sum[N][31],lg[N];
void add(int x,int y)
{
ver[++tot]=y;
nex[tot]=head[x];
head[x]=tot;
val[x]++;
}
void dfs(int u,int fa)
{
depth[u]=depth[fa]+1;
f[u][0]=fa;
sum[u][0]=val[u];
for(int i=1;i<31;i++)
{
f[u][i]=f[f[u][i-1]][i-1];
sum[u][i]=sum[f[u][i-1]][i-1]+sum[u][i-1];
}
for(int i=head[u];i;i=nex[i])
if(ver[i]!=fa) dfs(ver[i],u);
}
int lca(int x,int y)
{
if(depth[x]>depth[y]) swap(x,y);
int tmp=depth[y]-depth[x],ans=0;
for(int j=0;tmp;j++,tmp>>=1)
if(tmp&1) ans+=sum[y][j],y=f[y][j];
if(y==x) return ans+val[y];
for(int j=30;j>=0&&y!=x;j--)
{
if(f[x][j]!=f[y][j])
{
ans+=sum[x][j]+sum[y][j];
x=f[x][j];y=f[y][j];
}
}
ans+=sum[x][0]+sum[y][0]+sum[f[x][0]][0];
return ans;
}
int main()
{
cin>>n>>m;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(1,0);
while(m--)
{
scanf("%d%d",&x,&y);
printf("%d\n",lca(x,y));
}
return 0;
}