向上标记法 单次查询O(n)
步骤:
先从点1往上走到根节点,走过的点都标记
再从点2往上走,碰到的第一个带标记的点就是最近公共祖先
缺点:
求LCA最直接的方法,单次查询的时间复杂度最坏为 O(n)(看起来好像还挺快的,不过什么题会只有一次查询呢,因此很少使用)查询方式是从一个点向上搜,到和另一个点深度相同的地方就一起搜直到搜到。
倍增法:预处理(nlogn) + 单次查询(logn)
向上标记法之所以求LCA慢,是因为这dd一次只能爬一格,那么我们一次爬很多格不就OK了,二进制拆分可以组成任何想要的数,因此,用二进制来优化
int fa[N][16]; //fa[i][j]表示结点i往上跳2^j步所到达的结点
int depth[N]; //表示每个结点到根节点的深度
关键是理解二进制拼凑 在这里是怎么体现的
即 x,y从同一高度同时起跳后,在f[x][0]!=f[y][0] 的约束下 我们能跳的最多的步数跳完后 x,y就达到了LCA的下面一层
假定我们知道 x,y出发点为第1层 , LCA下一层为第12层
那么最多能跳的步数t = 12-1 = 11 = (1011)2 = 最多能跳2^3 + 2^2 + 2^0 步
所以我们就通过从大到小枚举k使得我们刚好跳11步而不能跳超过12步
但实际上我们并不知道要跳11步,所以我们可以通过f[x][0]!=f[y][0]的约束来实现
即f[x][k] = f[y][k] 那就不跳(不拼凑2^k),跳过头了,可能跳到LCA上面的其他公共祖先了
f[x][k] != f[y][k] 那就跳(拼凑2^k),没跳过LCA,可以跳
预处理:
求出每个点向上走2^k步的节点是谁
则f[i][j] = f[mid][j-1] = f[f[i][j-1]][j-1] ,先跳前面j-1步,再跳后面j-1步
查询:
步骤1 把两个点跳到同一层 把x跳到和y同一层
步骤2 两个同层不同结点跳到LCA的下一层
why 最近公共祖先的下一层 not 最近公共祖先?
方便判断,假如f[x][k] == f[y][k] <=> f[x][k] or f[y][k]是x和y的一个公共祖先 但不一定是最近的,举个栗子
此时f[x][1] == f[y][1] = 节点2 是x和y的一个公共祖先 但不是最近公共祖先4
,但由于我们是从大到小拼凑的,假如拼凑终止条件为f[x][k] == f[y][k]
,则此时会停在公共祖先2而非最近公共祖先4
1
/ \
2 3
/
4
/ \
x y
两个重要的哨兵:
(1)如a,b都跳出根节点,fa[a][k]==fa[b][k]==0
(2)depth[0] = 0: 如果从i开始跳2^j步会跳过根节点 ,那么fa[i][j] = 0 depth[fa[i][j]] = depth[0] = 0,则表示跳过了根节点,不往上跳了
练习题目:
链接:https://www.acwing.com/problem/content/1174/
AC代码:
#include<iostream>
#include<vector>
#include<queue>
#include<string.h>
using namespace std;
const int N=40005;
int n,m,root;
int fa[N][16]; //fa[i][j]表示结点i往上跳2^j步所到达的结点
int depth[N]; //表示每个结点到根节点的深度
vector<int>tr[N]; //邻接表
queue<int>tmp;
const int INF=0x3f3f3f3f;
void bfs(int s)
{
memset(depth,INF,sizeof(depth));
tmp.push(s);
depth[0]=0;depth[s]=1;
while(!tmp.empty())
{
int u=tmp.front();
tmp.pop();
for(int i=0;i<tr[u].size();i++)
{
int j=tr[u][i];
if(depth[j]==INF) //还没访问过,避免一个结点有多个父节点的情况
{
depth[j]=depth[u]+1;
tmp.push(j);
fa[j][0]=u; //注意这个不是移动0步,而是移动2^0=1步
for(int k=1;k<=15;k++)
{
fa[j][k]=fa[fa[j][k-1]][k-1];
}
}
}
}
}
int lca(int x,int y)
{
if(depth[x]<depth[y]) swap(x,y); //保证x比y深
//先将x移动到跟y一样深
for(int k=15;k>=0;k--)
{
// 设置了哨兵depth[0] = 0: 如果从i开始跳2^j步会跳过根节点
// fa[fa[j][k-1]][k-1] = 0
// 那么fa[i][j] = 0 depth[fa[i][j]] = depth[0] = 0,depth[y]>=1正是我们不进行处理的特殊情况
if(depth[fa[x][k]]>=depth[y]) //x移动2^k步后还是比y深,就继续往上移动
{
x=fa[x][k];
}
}
if(x==y) return x;
//将x和y一起往上移动到最近公共祖先处
for(int k=15;k>=0;k--)
{
// 假如a,b都跳出根节点,fa[a][k]==fa[b][k]==0 不符合更新条件,即不跳过去的特殊情况
if(fa[x][k]!=fa[y][k]) //先判断再移动,如果跳过去后已经跳过了LCA了就不跳
{
x=fa[x][k];
y=fa[y][k];
}
}
//最后出来的时候fa[x][0]或者fa[y][0]就为LCA
return fa[x][0];
}
int main()
{
cin>>n;
for(int i=0;i<n;i++)
{
int a,b;
cin>>a>>b;
tr[a].push_back(b);
tr[b].push_back(a);
if(b==-1) root=a;
}
bfs(root); //预处理nlog(n)
cin>>m;
for(int i=0;i<m;i++) //多次询问,每次询问O(logn)
{
int x,y;cin>>x>>y;
int p=lca(x,y);
if(p==x) cout<<1<<endl;
else if(p==y) cout<<2<<endl;
else cout<<0<<endl;
}
return 0;
}
tarjan离线求LCA O(n+多次查询m)
在线做法:边读边做
离线做法:先读完,再全部处理,最后全部输出。
tarjan的求法用的其实没有倍增法多,倍增法就比较快的了
Tarjon本质就是对向上标记法的一个优化,任取一个节点当成根节点进行dfs优先遍历,把所有节点分成三种类型
1)已经遍历并且回溯的标记成2
2)正在遍历的没有回溯的标记成1
3)未遍历的标记成0
注意:两个标记计算的顺序一定不能乱。
1)遍历了当前点的子节点后把子节点的祖宗更新成当前点
2)当遍历完当前点所有子树回溯到当前点的时候才可以用这个点来计算所有之前为2的点的查询
参考链接:https://blog.csdn.net/li1615882553/article/details/79762771
AC代码:
#include<iostream>
#include<vector>
#include<queue>
#include<string.h>
using namespace std;
const int N=20005,INF=0x3f3f3f3f;
int n,m,x,y,k;
typedef pair<int,int>PII;
vector<PII>tr[N]; //邻接表
vector<PII>query[N]; //存储查询 ,first存与下标index相关查询的另一个结点值,second存询问编号用于存储结果
int p[N]; //并查集
int ans[N],dist[N],st[N];
void dfs(int u,int fa)
{
for(int i=0;i<tr[u].size();i++)
{
int j=tr[u][i].first;
if(j==fa) continue; //遍历与u相连的除了u的父亲以外的结点
dist[j]=dist[u]+ tr[u][i].second;
dfs(j,u);
}
}
int find(int j)
{
if(p[j]!=j) p[j]=find(p[j]); //往上查找的同时合并并查集
return p[j];
}
void tarjan(int u)
{
st[u]=1; //正在遍历
for(int i=0;i<tr[u].size();i++)
{
int j=tr[u][i].first;
if(!st[j]) //没有遍历过
{
tarjan(j);
p[j]=u; //遍历完了一个子节点
}
}
//子树都遍历完,回溯到u结点,求与u相关的查询
for(int i=0;i<query[u].size();i++)
{
int j=query[u][i].first;
if(st[j]==2) //只有已经遍历完并回溯完的才能求值
{
int lca=find(j); //找到j所属并查集的祖先,进行合并O(1)
int id=query[u][i].second;
ans[id]=dist[u]+dist[j]-2*dist[lca];
}
}
st[u]=2; //遍历并回溯完了
}
int main()
{
cin>>n>>m;
for(int i=1;i<n;i++)
{
cin>>x>>y>>k;
tr[x].push_back({y,k});
tr[y].push_back({x,k});
}
for(int i=0;i<m;i++)
{
cin>>x>>y;
query[x].push_back({y,i});
query[y].push_back({x,i});
}
for(int i=1;i<=n;i++) p[i]=i;
dfs(1,-1); //随便令一个为根节点,求每个点到根节点的距离,根节点的父节点就令为-1
tarjan(1);
for(int i=0;i<m;i++)
{
cout<<ans[i]<<endl;
}
return 0;
}