lca是在树上的概念,从根节点到某个点的最短路径上所有的点,都可以视为该点的祖先。而最近公共祖先就是对于两个点来说,距离最近的公共祖先,有两种情况,一种是这个祖先是其中一个点,另一种情况是这个祖先并非两点之中的任何一个点。
有三种算法可以求lca:
1.标记法,找出一个目标点的最短路,记录路径,然后从目标点往根节点遍历,标记最短路径,再从另一个点这么做一遍,当另一个点往前遍历的时候,最早遇到的被标记过的点就是所求。
一次查询O(n)
2.倍增法,这个方法比较巧妙,也是三种做法中最容易实现的。原理就是利用二进制优化,维护一个数组f[i][k]表示从i开始往上跳2^k个点后会落到哪个点。这个数组看似统计麻烦,但实际上可以用递推来实现:f[i][k]=f[f[i][k-1]][k-1];另外还有预处理每个点所在的层数,对于两个点,先将它们调整到同一层,然后再一起往上找,找到最小的f[i][k]!=f[j][k],然后再往上一步就是答案。
预处理:O(nlogn)
查询:O(logn)
3.Tarjan,这个用于离线处理(先全部读入,最后统一输出)lca问题,实际上是对标记法的一个优化。将图中所有的点分为3类:
0.没有被遍历过的点
1.遍历过还没有回溯过的点
2.遍历过且已经回溯过的点
对于正在遍历的点,如果要求它与已经回溯过的点之间的公共祖先,如图c与a,那么就是b;c与d,那么就是e,也即已回溯点所在子树的根节点,不用担心每次都找到根节点,如果是正在遍历的路径上的点如点e,由于还没有回溯,所以它的根节点就是它自己,所在在找d与c的时候就不会出差错。所以我们只要在回溯的时候再修改每个点的父节点即可实现。
时间复杂度O(n+m)(n是点数,m是询问数)
我们结合例题来具体分析一下。
1172. 祖孙询问(活动 - AcWing)
思路:这题很直白,求出两点的公共祖先,然后判断是否是两点中的任何一个。
#include<bits/stdc++.h>
using namespace std;
const int N=40010,M=80010;
int n,m,q;
int h[N],e[M],ne[M],idx;
int dep[N];
int fa[N][20];
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs(int u,int f)
{
dep[u]=dep[f]+1;//0层置空
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j==f) continue;
fa[j][0]=u;
for(int k=1;k<=15;k++) fa[j][k]=fa[fa[j][k-1]][k-1];
dfs(j,u);
}
}
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int i=15;i>=0;i--)
{
if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];
}
if(x==y) return x;
for(int i=15;i>=0;i--)
{
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
}
return fa[x][0];
}
int id[N];
int main()
{
scanf("%d",&m);
int r;
memset(h,-1,sizeof h);
for(int i=1;i<=m;i++)
{
int a,b;
scanf("%d%d",&a,&b);
if(b!=-1) add(a,b),add(b,a);
else r=a;
}
dfs(r,r);
scanf("%d",&q);
while(q--)
{
int x,y;
scanf("%d%d",&x,&y);
int p=lca(x,y);
if(p==x) printf("1\n");
else if(p==y)printf("2\n");
else printf("0\n");
}
}
1171. 距离(活动 - AcWing)
这个题如果用倍增法,那么时间复杂度就O(nlogn+mlogn),虽然能写,但是这种询问比点还多的情况,我们可以用Tarjan算法来写。
#include<bits/stdc++.h>
using namespace std;
const int N=10010,M=20010;
int n,m;
int h[N],e[M],ne[M],w[M],idx;
int p[N],st[N],ans[M],d[N];
vector<pair<int,int>>q[N];
void add(int a,int b,int c)
{
w[idx]=c,e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs(int u,int f)
{
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==f) continue;
d[j]=d[u]+w[i];
dfs(j,u);
}
}
int find(int x)
{
if(x!=p[x]) p[x]=find(p[x]);
return p[x];
}
void Tarjan(int u)
{
st[u]=1;
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(!st[j])
{
Tarjan(j);
p[j]=u;
}
}
for(auto it:q[u])
{
int y=it.first,id=it.second;
if(st[y]==2)
{
int anc=find(y);
ans[id]=d[u]+d[y]-2*d[anc];
}
}
st[u]=2;
}
int main()
{
scanf("%d%d",&n,&m);
memset(h,-1,sizeof h);
for(int i=0;i<n-1;i++)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add(a,b,c),add(b,a,c);
}
for(int i=0;i<m;i++)
{
int a,b;
scanf("%d%d",&a,&b);
if(a!=b)
{
q[a].push_back({b,i});
q[b].push_back({a,i});
}
}
for(int i=1;i<=n;i++) p[i]=i;
dfs(1,-1);
Tarjan(1);
for(int i=0;i<m;i++) printf("%d\n",ans[i]);
return 0;
}
ps:ans数组要开到和询问一样大,一定要注意这一点。
356. 次小生成树(356. 次小生成树 - AcWing题库)
思路:求次小生成树问题,我们之前的做法是预处理树上任意两点之间的路径的边权最大值和最小值,然后枚举每条非树边 ,计算用这条边进行替换后的结果是多少,统计一个最小值。这个算法要对每个点进行一次dfs,时间复杂度还是有点高,这里我们可以用lca算法进行优化。具体步骤如下:
先kluskal求出最小生成树,然后根据最小生成树建一张无向图,然后用倍增法写lca,在lca预处理的时候,将d1[i][k]、d2[i][k](i上跳2^k的过程中的所经过的边的最大值和次大值)预处理出来,对所有的非树边执行lca查询,在lca查询的时候,将两个到最近公共祖先的过程中经过的所有最大值和次大值存下来,然后找出总的最大和次大,进行计算,返回结果。
这里用到的性质就是,树上两点之间的最短距离就是两者到lca的路径长度之和,所以我们可以通过lca的倍增算法来优化查找两者路径上边权最大和次大的长度。
#include<bits/stdc++.h>
using namespace std;
const int N=100010,M=300010,inf=0x3f3f3f3f;
int n,m;
struct edge{
int a,b,c;
bool flag;
bool operator<(const edge &x) const{
return c<x.c;
}
}ed[M];
int p[N];
int h[N],e[M],ne[M],w[M],idx;
int dep[N],fa[N][20],d1[N][20],d2[N][20];
int q[N];
void add(int a,int b,int c)
{
w[idx]=c,e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
int find(int x)
{
if(p[x]!=x) p[x]=find(p[x]);
return p[x];
}
long long kluskal()
{
for(int i=1;i<=n;i++) p[i]=i;
sort(ed+1,ed+1+m);
memset(h,-1,sizeof h);
long long res=0;
for(int i=1;i<=m;i++)
{
int a=ed[i].a,b=ed[i].b,c=ed[i].c;
int pa=find(a),pb=find(b);
if(pa!=pb)
{
add(a,b,c),add(b,a,c);
p[pa]=pb;
res+=c;
ed[i].flag=1;
}
}
return res;
}
void dfs(int u,int f)
{
dep[u]=dep[f]+1;
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==f) continue;
fa[j][0]=u;
d1[j][0]=w[i],d2[j][0]=-inf;
for(int k=1;k<=16;k++)
{
int anc=fa[j][k-1];
fa[j][k]=fa[anc][k-1];
int td[]={d1[j][k-1],d2[j][k-1],d1[anc][k-1],d2[anc][k-1]};
d1[j][k]=d2[j][k]=-inf;
for(int u=0;u<4;u++)
{
if(td[u]>d1[j][k]) d2[j][k]=d1[j][k],d1[j][k]=td[u];
else if(td[u]!=d1[j][k]&&td[u]>d2[j][k]) d2[j][k]=td[u];
}
}
dfs(j,u);
}
}
int dist[2*N];
int lca(int a,int b,int c)
{
int cnt=0;
if(dep[a]<dep[b]) swap(a,b);
for(int i=16;i>=0;i--)
{
if(dep[fa[a][i]]>=dep[b])
{
dist[cnt++]=d1[a][i];
dist[cnt++]=d2[a][i];
a=fa[a][i];
}
}
if(a!=b)
{
for(int i=16;i>=0;i--)
{
if(fa[a][i]!=fa[b][i])
{
dist[cnt++]=d1[a][i];
dist[cnt++]=d2[a][i];
dist[cnt++]=d1[b][i];
dist[cnt++]=d2[b][i];
a=fa[a][i],b=fa[b][i];
}
}
dist[cnt++]=d1[a][0];
dist[cnt++]=d1[b][0];
}
int td1=-inf,td2=-inf;
for(int i=0;i<cnt;i++)
{
int d=dist[i];
if(d>td1) td2=td1,td1=d;
else if(d!=td1&&d>td2) td2=d;
}
if(c>td1) return c-td1;
else if(c>td2) return c-td2;
else return inf;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=m;i++)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
ed[i]={a,b,c};
}
long long sum=kluskal();
dfs(1,1);
long long mi=1e18;
for(int i=1;i<=m;i++)
{
if(!ed[i].flag)
{
mi=min(mi,sum+lca(ed[i].a,ed[i].b,ed[i].c));
}
}
cout<<mi;
}
352. 闇の連鎖(352. 闇の連鎖 - AcWing题库)
思路:这道题的主要边构成了一棵树,
如图的非树边a,如果第一次砍的是树上的红色路径,那么很显然第二次只能砍a;如果不是,第二次就可以
如果再加一条绿色的非树边,那么显然如果砍了两条路径重合的边,那么需要把两条路径都砍断才能分,如果砍了各自的路径,那么只用再砍对应的一条即可,如果砍了非路径的边,那么两条任砍其一即可。所以我们只需要直到每一条主要路径砍完还需要砍多少条非树边即可实现。
这里我们对于一条非树边,它们的最短路径上所有的边的值都要加1,这里可以利用差分实现,即,两点的d各加1,lca的d减2.
统计的时候就dfs即可,先一直往下搜,然后返回一个值表示以当前节点为根的子树的中砍边的情况。相当于把边上的权值赋给点,或者再仔细一点就是将每个点的入边上的值赋给这个点,每个点只有一个入边,在根节点处刚好全部抵消没有任何影响。(标记点搜点更容易)
#include<bits/stdc++.h>
using namespace std;
const int N=100010,M=200010;
int n,m;
int ans;
int dep[N];
int h[N],e[M],ne[M],idx;
int fa[N][20];
int d[N];
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs1(int u,int f)
{
dep[u]=dep[f]+1;
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==f) continue;
fa[j][0]=u;
for(int k=1;k<=16;k++)
{
int anc=fa[j][k-1];
fa[j][k]=fa[anc][k-1];
}
dfs1(j,u);
}
}
int lca(int a,int b)
{
if(dep[a]<dep[b])swap(a,b);
for(int i=16;i>=0;i--)
if(dep[fa[a][i]]>=dep[b])
a=fa[a][i];
if(a==b) return a;
for(int i=16;i>=0;i--)
if(fa[a][i]!=fa[b][i])
a=fa[a][i],b=fa[b][i];
return fa[a][0];
}
int dfs(int u,int f)
{
int res=d[u];
for(int i=h[u];~i;i=ne[i])
{
int j=e[i];
if(j==f) continue;
int s=dfs(j,u);
if(s==0) ans += m;
else if(s==1) ans+=1;
res += s;
}
return res;
}
int main()
{
scanf("%d%d",&n,&m);
memset(h,-1,sizeof h);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
dfs1(1,1);
for(int i=0;i<m;i++)
{
int a,b;
scanf("%d%d",&a,&b);
int p=lca(a,b);
d[a]++,d[b]++,d[p]-=2;
}
dfs(1,1);
cout<<ans;
}