学习了LCA的两种解法:用RMQ 和 trajan离线解法
现讲一下离线解法:
个人觉得这个解法就是从根节点开始DFS,然后到达底部之后再回溯,回溯的时候在把一个一个的点加入并查集,这样的话保证的是在找到两个点之前不会把他们的lca的祖先压入并查集,至于到底是怎么回溯的,引用:
0
|
1
/ \
2 3
比如说在这里,如果0为根的话,那么1是2和3的父亲结点,0是1的父亲结点,0和1都是2和3的公共祖先结点,但是1才是最近的公共祖先结点,或者说1是2和3的所有祖先结点中距离根结点最远的祖先结点。
在求解最近公共祖先为问题上,用到的是Tarjan的思想,从根结点开始形成一棵深搜树,非常好的处理技巧就是在回溯到结点u的时候,u的子树已经遍历,这时候才把u结点放入合并集合中,这样u结点和所有u的子树中的结点的最近公共祖先就是u了,u和还未遍历的所有u的兄弟结点及子树中的最近公共祖先就是u的父亲结点。以此类推。。这样我们在对树深度遍历的时候就很自然的将树中的结点分成若干的集合,两个集合中的所属不同集合的任意一对顶点的公共祖先都是相同的,也就是说这两个集合的最近公共最先只有一个。对于每个集合而言可以用并查集来优化,时间复杂度就大大降低了,为O(n + q),n为总结点数,q为询问结点对数
注意:在求找到后lca是finds(之前经过的那个点);
PS: 当求一对数的lca的时候调用一次就行了;
当求多个的时候,可以先用数组存储,然后也是走一遍就行了;
RMQ解法:
用一个图就可以解决所有的疑问:
输出一对的lca:
RMQ:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
const int maxn = 10000 + 10;
int k;
int head[maxn << 1];
int in[maxn << 1];
int first[maxn << 1];
int vs[maxn << 1];
int depth[maxn << 1];
int dp[maxn << 1][20];
struct Node
{
int to;
int next;
}edge[maxn];
void add_edge(int i,int x,int y)
{
edge[i].to = y;
edge[i].next = head[x];
head[x] = i;
}
void Init()
{
memset(head,-1,sizeof(head));
k = 0;
memset(in,0,sizeof(in));
}
void dfs(int v,int d)
{
first[v] = k;
vs[k] = v;
depth[k ++] = d;
for(int i = head[v]; i != -1; i = edge[i].next)
{
dfs(edge[i].to,d + 1);
vs[k] = v;
depth[k ++] = d;
}
}
void RMQ_ST()
{
for(int i = 0; i < k; i ++)
dp[i][0] = i;
for(int j = 1; (1 << j) < k; j ++)
{
for(int i = 0; i + (1 << j) < k; i ++)
{
int x = dp[i][j - 1], y = dp[i + (1 << (j - 1))][j - 1];
dp[i][j] = depth[x] < depth[y] ? x : y;
}
}
}
int RMQ(int x,int y)
{
if(x > y)
{
int t = x;
x = y;
y = t;
}
int len = 0;
while(1 << ( len + 1) <= (y - x + 1))
len ++;
int xs = dp[x][len],ys = dp[y - (1 << len) + 1][len];
int t = depth[xs] < depth[ys] ? xs : ys;
// cout << t << endl;
return vs[t];
}
int main()
{
int Tcase;
scanf("%d",&Tcase);
for(int ii = 1; ii <= Tcase; ii ++)
{
Init();
int n;
scanf("%d",&n);
for(int i = 1; i < n; i ++)
{
int x,y;
scanf("%d%d",&x,&y);
add_edge(i,x,y);
in[y] ++;
}
for(int i = 1; i <= n; i ++)
{
if(!in[i])
{
dfs(i,1);
break;
}
}
RMQ_ST();
int x,y;
scanf("%d%d",&x,&y);
cout << RMQ(first[x],first[y]) << endl;
}
return 0;
}
trajan:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
const int maxn = 40000 + 10;
int fa[maxn];
bool vis[maxn];
bool root[maxn];
vector<int>v[maxn];
int n,m;
void Init()
{
memset(root,true,sizeof(root));
memset(vis,false,sizeof(vis));
for(int i = 0; i < maxn ; i ++)
{
v[i].clear();
fa[i] = i;
}
}
int finds(int x)
{
return x == fa[x] ? x : (fa[x] = finds(fa[x]));
}
void unions(int x,int y)
{
int xx = finds(x), yy = finds(y);
if(xx != yy)
{
fa[yy] = xx;
}
}
int lca;
void Lca(int u,int x,int y)
{
for(int i = 0; i < v[u].size(); i ++)
{
if(!vis[v[u][i]])
{
Lca(v[u][i],x,y);
unions(u,v[u][i]);
}
}
vis[u] = true;
if(u == x && vis[y])
{
lca = finds(y);
return ;
}
if(u == y && vis[x])
{
lca = finds(x);
return;
}
}
int main()
{
int Tcase;
scanf("%d",&Tcase);
for(int ii = 1; ii <= Tcase; ii ++)
{
Init();
scanf("%d",&n);
for(int i = 1; i < n; i ++)
{
int x,y;
scanf("%d%d",&x,&y);
root[y] = false;
v[x].push_back(y);
}
int roots;
m = 1;
for(int i = 1; i <= n; i ++)
if(root[i])
{
roots = i;
break;
}
int x,y;
scanf("%d%d",&x,&y);
Lca(roots,x,y);
cout << lca << endl;
}
return 0;
}
题意: 输出一颗树的两个点的最短距离,
思路:求出lca,然后用dis[x] + dis[y] - 2 * dis[lca] * 2;
rmq:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
const int maxn = 40000 + 10;
int k;
int head[maxn << 1];
int in[maxn << 1];
int first[maxn << 1];
int vs[maxn << 1];
int depth[maxn << 1];
int dp[maxn << 1][30];
int dir[maxn << 1];
struct Node
{
int to;
int w;
int next;
}edge[maxn];
void add_edge(int i,int x,int y,int z)
{
edge[i].to = y;
edge[i].w = z;
edge[i].next = head[x];
head[x] = i;
}
void Init()
{
memset(dir,0,sizeof(dir));
memset(head,-1,sizeof(head));
k = 0;
memset(in,0,sizeof(in));
}
void dfs(int v,int d)
{
first[v] = k;
vs[k] = v;
depth[k ++] = d;
for(int i = head[v]; i != -1; i = edge[i].next)
{
dir[edge[i].to] = dir[v] + edge[i].w;
dfs(edge[i].to,d + 1);
vs[k] = v;
depth[k ++] = d;
}
}
void RMQ_ST()
{
for(int i = 0; i < k; i ++)
dp[i][0] = i;
for(int j = 1; (1 << j) < k; j ++)
{
for(int i = 0; i + (1 << j) < k; i ++)
{
int x = dp[i][j - 1], y = dp[i + (1 << (j - 1))][j - 1];
dp[i][j] = depth[x] < depth[y] ? x : y;
}
}
}
int RMQ(int x,int y)
{
if(x > y)
{
int t = x;
x = y;
y = t;
}
int len = (int)log((double)(y - x + 1)) /log(2.0);
int xs = dp[x][len],ys = dp[x + (1 << len)][len];
int t = depth[xs] < depth[ys] ? xs : ys;
// cout << t << endl;
return vs[t];
}
int main()
{
int Tcase;
scanf("%d",&Tcase);
for(int ii = 1; ii <= Tcase; ii ++)
{
Init();
int n,m;
scanf("%d%d",&n,&m);
for(int i = 1; i < n; i ++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
add_edge(i,x,y,z);
in[y] ++;
}
for(int i = 1; i <= n; i ++)
{
if(!in[i])
{
dfs(i,1);
break;
}
}
RMQ_ST();
int x,y;
while(m --)
{
scanf("%d%d",&x,&y);
int lca = RMQ(first[x],first[y]);
// cout << lca << endl;
// cout << dir[x] << " "<< dir[y] << " " << dir[lca] << endl;
cout << dir[x] + dir[y] - 2 * dir[lca] << endl;
}
}
return 0;
}
#include<bits/stdc++.h>
using namespace std;
const int maxn = 40000 + 10;
int dis[maxn];
int fa[maxn];
bool vis[maxn];
struct Node
{
int to;
int w;
};
struct LCA
{
int x,y;
int lca;
}ans[maxn];
bool root[maxn];
vector<Node>v[maxn];
int n,m;
void Init()
{
memset(dis,0,sizeof(dis));
memset(ans,0,sizeof(vis));
memset(root,true,sizeof(root));
memset(vis,false,sizeof(vis));
for(int i = 0; i < maxn ; i ++)
{
v[i].clear();
fa[i] = i;
}
}
int finds(int x)
{
return x == fa[x] ? x : (fa[x] = finds(fa[x]));
}
void unions(int x,int y)
{
int xx = finds(x), yy = finds(y);
if(xx != yy)
{
fa[yy] = xx;
}
}
void Lca(int u)
{
for(int i = 0; i < v[u].size(); i ++)
{
if(!vis[v[u][i].to])
{
dis[v[u][i].to] = dis[u] + v[u][i].w;
Lca(v[u][i].to);
unions(u,v[u][i].to);
}
}
vis[u] = true;
for(int i = 0; i < m; i ++)
{
if(!ans[i].lca)
{
if(u == ans[i].x && vis[ans[i].y])
{
ans[i].lca = finds(ans[i].y);
}
if(u == ans[i].y && vis[ans[i].x])
{
ans[i].lca = finds(ans[i].x);
}
}
}
}
int main()
{
int Tcase;
scanf("%d",&Tcase);
for(int ii = 1; ii <= Tcase; ii ++)
{
Init();
scanf("%d%d",&n,&m);
for(int i = 1; i < n; i ++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
root[y] = false;
Node t;
t.to = y;
t.w = z;
v[x].push_back(t);
}
int roots;
for(int i = 1; i <= n; i ++)
if(root[i])
{
roots = i;
break;
}
for(int i = 0; i < m; i ++)
{
scanf("%d%d",&ans[i].x,&ans[i].y);
}
Lca(roots);
for(int i = 0; i < m; i ++)
cout << dis[ans[i].x] + dis[ans[i].y] - 2 * dis[ans[i].lca] << endl;
}
return 0;
}