最近公共祖先学习笔记

y总讲解

 常用写法主要为(2)和(3),即倍增法和 Tarjan 法,其中倍增法比较好用。

模板:

1. 倍增法

先做一遍 bfsbfs 作用:

求出每一个点的深度,其中 0 号点的深度为 0,根节点的深度为 1

其中 0 号点相当于哨兵,当 fa[i][j] 跳过根节点时, fa[i][j]=0,所以其深度为 0,在之后的 lca 操作时更方便

将 dep 赋初值为 0x3f 可以使得节点一直向下走,不走回头路

dep 数组更新方式:dep[v]=dep[u]+1

fa[i][j] 数组表示当前节点 i ,向上走 2^j 步可以走到的节点,其中 fa[i][0] 表示节点 i 的父节点

fa 数组更新方式:fa[v][k]=fa[fa[v][k-1]][k-1]

先让节点 v 走 2^{k-1} 步,然后再走 2^{k-1} 步,共走  2^{k} 步

lca 函数即为求 a,b 两点的最近公共祖先

  1. 先让深度较大的点为 a
  2. 让 a 节点不断向上跳,直到 dep[a]==dep[b] ,即 a,b 节点深度相同。向上跳的时候要先跳大步,再跳小步,即倒着循环 k 
  3. 判断 a,b 是否为同一个点,若为同一个点,则返回 a;若不为同一个点,继续下面步骤
  4. 此时 a,b 两点深度相同,让他们同时向上跳 2^k 步,若二者的 fa[a][k]==fa[b][k] ,则说明 fa[a][k] 时二者的公共祖先,但不代表是最小的公共祖先。因此我们当 fa[a][k]!=fa[b][k] 时,二者共同向上跳 2^k 步,更新 a=fa[a][k],b=fa[b][k]
  5. 第 4 步的循环结束时,我们会得到 a,b,并且二者的最小公共祖先为 fa[a][0] ,我们此时返回 fa[a][0]
void bfs(){
	memset(depth,0x3f,sizeof(depth));
	depth[0]=0; depth[root]=1;
	int hh=0,tt=0;
	q[0]=root;
	while(hh<=tt){
		int u=q[hh++];
		for(int i=head[u];i;i=e[i].nex){
			int v=e[i].to;
			if(depth[v]>depth[u]+1){
				depth[v]=depth[u]+1;
				q[++tt]=v;
				fa[v][0]=u;
				for(int k=1;k<=15;k++)
					fa[v][k]=fa[fa[v][k-1]][k-1];
			}
		}
	}
}
int lca(int a,int b){
	if(depth[a]<depth[b]) swap(a,b);
	for(int k=15;k>=0;k--)
		if(depth[fa[a][k]]>=depth[b])
			a=fa[a][k];
	if(a==b) return a;
	for(int k=15;k>=0;k--)
		if(fa[a][k]!=fa[b][k]){
			a=fa[a][k];
			b=fa[b][k];
		}
	return fa[a][0];
}
int main(){
    bfs();
}
2. Tarjan 法

做法:

已经走过并且回溯过的点 vis 记为 2,走过但还未回溯的点 vis 记为 1,还未走过的点记为 0

我们观察上个图片,红色链为我们正在遍历的点,还未回溯,红色链上所有点的 p[x]==x ,即都等于他们自身,在回溯后才会对他们的父节点进行更新,且 u 回溯后才会使 vis[u]=2 。对于图上的 x,y 来说,我们发现其最小公共祖先为 y 所在的子树的根节点 t ,且根节点 t 在红色链上。

因此我们发现对于 vis[a]=1 ,所有 vis[b]=2 的节点,他们 lca(a,b)=find(b)

Tarjan 算法是离线处理所有询问,所以我们先将所有的问题保存下来,用 vector<pair<int,int> > q[N] 来存储问题,pair 第一维存节点编号,第二维存问题编号。当问 a,b 两点时,存储为 

q[a].push_back({b,i});
q[b].push_back({a,i});

做 tarjan 时,当前点 u 的所有子节点处理完之后再处理当前点 u。对于所有关于节点 u 的询问,只有当 v 的 vis[v]==2 时再进行操作处理。

res[i] 数组存的是第 i 个询问的答案

注意:并查集一定记得赋初值

void tarjan(int u){
	vis[u]=1;
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to,w=e[i].w;
		if(vis[v]) continue;
		tarjan(v);
		fa[v]=u;
	}
	for(int i=0;i<q[u].size();i++){
		int v=q[u][i].first,id=q[u][i].second;
		if(vis[v]!=2) continue;
		int anc=find(v);    
        res[id]=anc;
	}
	vis[u]=2;
}
int main(){
    tarjan(1);
}

总结:

  1. 求树上两点 x,y 距离时,ans=dis[x]+dix[y]-2*dis[anc]
  2. 树上差分:(详见例题 4 )给树上两点 x,y 之间的每一条路径的权值增加 c,则 d[x]+=c,d[y]+=c,d[anc]-=2*c。做完上述所有操作后,统计u \rightarrow v 边 i 的权值时,w[i] 等于以 v 为根的子树上所有点的 d[t] 累加和,即 \sum d[t]( t 属于以 v 为根的子树)。以下图为例

例题:

1. AcWing 1172. 祖孙询问

模板题,直接套板子就可以

#include<bits/stdc++.h>
using namespace std;
const int N=4e4+10,M=N*2;
struct node{
	int nex,to;
}e[M];
int head[N],depth[N],cnt,fa[N][16];
int root,q[N],n,m;
void add(int u,int v){
	e[++cnt].nex=head[u]; 
	e[cnt].to=v;
	head[u]=cnt;
}
void bfs(){
	memset(depth,0x3f,sizeof(depth));
	depth[0]=0; depth[root]=1;
	int hh=0,tt=0;
	q[0]=root;
	while(hh<=tt){
		int u=q[hh++];
		for(int i=head[u];i;i=e[i].nex){
			int v=e[i].to;
			if(depth[v]>depth[u]+1){
				depth[v]=depth[u]+1;
				q[++tt]=v;
				fa[v][0]=u;
				for(int k=1;k<=15;k++)
					fa[v][k]=fa[fa[v][k-1]][k-1];
			}
		}
	}
}
int lca(int a,int b){
	if(depth[a]<depth[b]) swap(a,b);
	for(int k=15;k>=0;k--)
		if(depth[fa[a][k]]>=depth[b])
			a=fa[a][k];
	if(a==b) return a;
	for(int k=15;k>=0;k--)
		if(fa[a][k]!=fa[b][k]){
			a=fa[a][k];
			b=fa[b][k];
		}
	return fa[a][0];
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		int a,b;
		scanf("%d%d",&a,&b);
		if(b==-1) root=a;
		else add(a,b),add(b,a);
	}
	bfs();
	scanf("%d",&m);
	while(m--){
		int a,b;
		scanf("%d%d",&a,&b);
		int p=lca(a,b);
		if(p==a) printf("1\n");
		else if(p==b) printf("2\n");
		else printf("0\n"); 
	}
	return 0;
}
2. AcWing 1171. 距离

树上两点之间距离模板题

#include<bits/stdc++.h>
using namespace std;
#define PII pair<int,int>
const int N=1e4+10,M=N*2;
struct node{
	int nex,to,w;
}e[M];
int head[N],cnt,dis[N];
int fa[N],res[M],vis[N],n,m;
vector<PII> q[N];
void add(int u,int v,int w){
	e[++cnt].nex=head[u];
	e[cnt].to=v;
	e[cnt].w=w;
	head[u]=cnt;
}
int find(int x){
	if(fa[x]==x) return fa[x];
	return fa[x]=find(fa[x]);
}
void dfs(int u,int father){
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to,w=e[i].w;
		if(v==father) continue;
		dis[v]=dis[u]+w;
		dfs(v,u);
	}
}
void tarjan(int u){
	vis[u]=1;
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to,w=e[i].w;
		if(vis[v]) continue;
		tarjan(v);
		fa[v]=u;
	}
	for(int i=0;i<q[u].size();i++){
		int v=q[u][i].first,id=q[u][i].second;
		if(vis[v]!=2) continue;
		int anc=find(v);
		res[id]=dis[u]+dis[v]-dis[anc]*2;
	}
	vis[u]=2;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<n;i++){
		int a,b,c;
		scanf("%d%d%d",&a,&b,&c);
		add(a,b,c); add(b,a,c);
	}
	for(int i=1;i<=m;i++){
		int a,b;
		scanf("%d%d",&a,&b);
		if(a==b) continue;
		q[a].push_back({b,i});
		q[b].push_back({a,i});
	}
	for(int i=1;i<=n;i++) fa[i]=i;
	dfs(1,-1);
	tarjan(1);
	for(int i=1;i<=m;i++)
		printf("%d\n",res[i]);
	return 0;
}
3. AcWing 356. 次小生成树

大致思路:

  • 先做一遍 Kruskal 算法,求出最小生成树的边权和为 sum
  • 对这颗最小生成树,我们枚举每一条非树边,边的两端点为 a,b ,权值为 w,求出这两个端点在最小生成树上的路径上的最大值 dis1 次大值 dis2 ,若 w>dis1 ,则 res=min(res,sum+(w-dis1));否则 res=min(res,sum+(w-dis2))

我们查找最大值最小值时用倍增法的 lca 来处理

设 d1[i][j] 为当前节点为 i,向上跳 2^j 步的路径上的最大值

设 d2[i][j] 为当前节点为 i,向上跳 2^j 步的路径上的次大值

预处理更新方式: 

for(int k=1;k<=16;k++){
	int anc=fa[v][k-1];
	fa[v][k]=fa[anc][k-1];
	int dis[4]={d1[v][k-1],d2[v][k-1],d1[anc][k-1],d2[anc][k-1]};
	d1[v][k]=d2[v][k]=-INF;
	for(int z=0;z<4;z++){
		int d=dis[z];
		if(d>d1[v][k]){
			d2[v][k]=d1[v][k];
			d1[v][k]=d;
		}
		else if(d!=d1[v][k]&&d>d2[v][k]) d2[v][k]=d;
	}
}

查找方式:

即把 a,b 两点路径上的所有 最大值和次大值 都加入到 dis[] 数组中作为可选的最大值次大值,然后我们遍历 dis[] 数组,找到其中的最大值 dis1 次大值 dis2 。

详见代码

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=1e5+10,M=3e5+10,INF=0x3f3f3f3f;
struct edges{
	int a,b,c,used;
}edge[M];
struct node{
	int nex,to,w;
}e[M*2];
int head[N],cnt,p[N],n,m;
int fa[N][17],d1[N][17],d2[N][17],dep[N];
int q[N];
bool cmp(edges t1,edges t2){
	return t1.c<t2.c;
}
void add(int u,int v,int w){
	e[++cnt].nex=head[u];
	e[cnt].to=v;
	e[cnt].w=w;
	head[u]=cnt;
}
int find(int x){
	if(p[x]==x) return p[x];
	return p[x]=find(p[x]);
}
ll Kruskal(){
	sort(edge+1,edge+1+m,cmp);
	ll res=0;
	for(int i=1;i<=m;i++){
		int a=edge[i].a,b=edge[i].b,c=edge[i].c;
		int pa=find(a),pb=find(b);
		if(pa==pb) continue;
		p[pb]=pa;
		edge[i].used=1;
		add(a,b,c); add(b,a,c);
		res+=edge[i].c;
	}
	return res;
}
void bfs(){
	memset(dep,0x3f,sizeof(dep));
	dep[0]=0; dep[1]=1;
	int hh=0,tt=0;
	q[0]=1;
	while(hh<=tt){
		int u=q[hh++];
		for(int i=head[u];i;i=e[i].nex){
			int v=e[i].to,w=e[i].w;
			if(dep[v]>dep[u]+1){
				dep[v]=dep[u]+1;
				fa[v][0]=u;
				q[++tt]=v;
				d1[v][0]=w; d2[v][0]=-INF;
				for(int k=1;k<=16;k++){
					int anc=fa[v][k-1];
					fa[v][k]=fa[anc][k-1];
					int dis[4]={d1[v][k-1],d2[v][k-1],d1[anc][k-1],d2[anc][k-1]};
					d1[v][k]=d2[v][k]=-INF;
					for(int z=0;z<4;z++){
						int d=dis[z];
						if(d>d1[v][k]){
							d2[v][k]=d1[v][k];
							d1[v][k]=d;
						}
						else if(d!=d1[v][k]&&d>d2[v][k]) d2[v][k]=d;
					}
				}
			}
			
		}
	}
}
int lca(int a,int b,int w){
	static int dis[200];
	int tot=0;
	if(dep[a]<dep[b]) swap(a,b);
	for(int k=16;k>=0;k--)
		if(dep[fa[a][k]]>=dep[b]){
			dis[++tot]=d1[a][k];
			dis[++tot]=d2[a][k];
			a=fa[a][k];
		}
	if(a!=b){
		for(int k=16;k>=0;k--){
			if(fa[a][k]!=fa[b][k]){
				dis[++tot]=d1[a][k];
				dis[++tot]=d2[a][k];
				dis[++tot]=d1[b][k];
				dis[++tot]=d2[b][k];
				a=fa[a][k]; b=fa[b][k];
			}
		}
		dis[++tot]=d1[a][0]; dis[++tot]=d1[b][0];
	}
	int dis1=-INF,dis2=-INF;
	for(int i=1;i<=tot;i++){
		int d=dis[i];
		if(d>dis1){
			dis2=dis1; dis1=d;
		}
		else if(d!=dis1&&d>dis2) dis2=d;
	}
	if(w>dis1) return w-dis1;
	return w-dis2;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++) p[i]=i;
	for(int i=1;i<=m;i++)
		scanf("%d%d%d",&edge[i].a,&edge[i].b,&edge[i].c);
	ll sum=Kruskal();
	bfs();
	ll res=1e18;
	for(int i=1;i<=m;i++){
		if(edge[i].used) continue;
		int a=edge[i].a,b=edge[i].b,c=edge[i].c;
		res=min(res,sum+lca(a,b,c));
	}
	printf("%lld\n",res);
	return 0;
}
4. AcWing 352. 闇の連鎖

树上差分模板题

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10,M=N*2;
struct node{
	int nex,to;
}e[M];
int n,m,head[N],cnt;
int fa[N][17],dep[N],d[N];
int q[N],ans;
void add(int u,int v){
	e[++cnt].nex=head[u];
	e[cnt].to=v;
	head[u]=cnt;
}
void bfs(){
	memset(dep,0x3f,sizeof(dep));
	dep[0]=0; dep[1]=1;
	int hh=0,tt=0;
	q[0]=1;
	while(hh<=tt){
		int u=q[hh++];
		for(int i=head[u];i;i=e[i].nex){
			int v=e[i].to;
			if(dep[v]>dep[u]+1){
				dep[v]=dep[u]+1;
				q[++tt]=v;
				fa[v][0]=u;
				for(int k=1;k<=16;k++)
					fa[v][k]=fa[fa[v][k-1]][k-1];
			}
		}
	}
}
int lca(int a,int b){
	if(dep[a]<dep[b]) swap(a,b);
	for(int k=16;k>=0;k--)
		if(dep[fa[a][k]]>=dep[b])
			a=fa[a][k];
	if(a==b) return a;
	for(int k=16;k>=0;k--)
		if(fa[a][k]!=fa[b][k]){
			a=fa[a][k];
			b=fa[b][k];
		}
	return fa[a][0];
}
int dfs(int u,int father){
	int res=d[u];
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].to;
		if(v==father) continue;
		int t=dfs(v,u);
		if(t==0) ans+=m;
		else if(t==1) ans+=1;
		res+=t;
	}
	return res;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<n;i++){
		int a,b;
		scanf("%d%d",&a,&b);
		add(a,b); add(b,a);
	}
	bfs();
	for(int i=1;i<=m;i++){
		int a,b;
		scanf("%d%d",&a,&b);
		int p=lca(a,b);
		d[a]++; d[b]++; d[p]-=2;
	}
	dfs(1,-1);
	printf("%d\n",ans);
	return 0;
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值