求最近公共祖先的算法
首先得是一棵树,这样才有祖先的概念,
设A,B为树中的两个节点,
那么C为A,B的最近公共祖先
暴力算法
首先想到的方法是
标记其中一个节点到根节点的路径,
然后再遍历另一个节点向根节点方向的路径,
出现的第一个有标记的位置即是两点的最近公共祖先
这样的话复杂度为 n ^ 2 的
算法优化
然后基于该方法优化的算法就是 LCA 了,
首先分析上面的算法哪里慢,首先是第一个遍历的节点需用遍历到根节点,其实最好的是只遍历到两个节点的最近公平祖先那个位置就可以了,
还有一个地方就是向上遍历的过程中,是一个一个遍历的,这个速度太慢了,可以用 二进制DP 的思想去优化,
即DP[ i , j ]为从 i 节点向上遍历 2 ^ j 步的节点
那么就是两个节点中的低的那个先和另一个同深度,然后两个一块向上跳,这样的话就不用走没必要的路线了,
跳的话需用 二进制DP初始化 ,计算深度的话需用 广搜
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 40040, M = 2 * N;
int root;
int h[N], n[M], ne[M], idx;
int fa[N][16];
int depth[N];
int q[N], tt, hh;
void add(int a, int b){
n[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void bfs()
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[root] = 1;
tt = 0, hh = 0;
q[0] = root;
while(hh <= tt)
{
int t = q[hh ++];
for(int i = h[t]; ~i; i = ne[i])
{
int j = n[i];
if(depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
q[++ tt] = j;
fa[j][0] = t;
for(int k = 1 ; k <= 15 ; k ++ )
fa[j][k] = fa[ fa[j][k-1] ][k - 1];//j 向上跳2^k步 即是 j向上跳2^k-1步后再跳2^k-1步
}
}
}
return ;
}
int lca(int a, int b)
{
if(depth[a] > depth[b]) swap(a, b);
for(int i = 15; i >= 0;i --)
if(depth[fa[b][i]] >= depth[a])
b = fa[b][i];
if(a == b)
return a;
for(int i = 15;i >= 0;i --)
if(fa[a][i] != fa[b][i])
{
a = fa[a][i];
b = fa[b][i];
}
return fa[a][0];
}
int main(){
int _;
memset(h, -1, sizeof h);
scanf("%d", &_);
for(int i = 0 ; i < _ ; i ++){
int a, b;
scanf("%d%d", &a ,&b);
if(b == -1)
root = a;
else
{
add (a, b);
add (b, a);
}
}
bfs();//预处理
scanf("%d",&_);
while(_--){
int a, b;
scanf("%d%d", &a, &b);
int t = lca(a, b);
if (t == a) puts("1");
else if (t == b) puts("2");
else puts("0");
}
return 0;
}
接下来就是一个离线的算法,(Tarjan)
即把查询先存起来,然后统一进行处理,统计输出
并查集是从底层向上添加的,
图中所示为求LCA(A,B)的过程,并查集中红色的边为正在搜索的边,灰色边(未连接)表示他们是分开的,并查集还并没有进行合并操作
如果以该节点为根节点的子树全部访问完后,就把该节点添加到他的上一层(并查集),这样的话就实现了从底层向上添加,
所求的 LCA 用邻接表现存好,如果访问该节点时该节点在邻接表中不为空,而且以另一个点为根节点的子树已经访问完成,另一个点的并查集根节点即为两个节点的最近公共祖先
把并查集和最近公共祖先的要求合并的非常完美
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
typedef pair <int, int> PII;
const int N = 10010, M = 2 * N;
int h[N], w[M], e[M], ne[M], idx;
int head[N];
int dist[N];
int res[M];
int st[N];
vector <PII> query[N];
void add(int a, int b, int c){
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
}
int find(int t){
if(head[t] != t) head[t] = find(head[t]);
return head[t];
}
int dfs(int u, int fa){
for(int i = h[u] ; ~i ; i = ne[i]){
int j = e[i];
if(j == fa) continue;
dist[j] = dist[u] + w[i];
dfs(j, u);
}
}
void tarjan(int u){
st[u] = 1;
for(int i = h[u]; ~i ; i = ne[i]){
int j = e[i];
if(!st[j]){
tarjan(j);//由于logn差不多是16,空间复杂度也没高到哪里去
head[j] = u;
}
}
for(auto item : query[u]){
int y = item.first, id = item.second;
if(st[y] == 2){//判断该节点是否已经完成并查集的更新
int anc = find(y);
res[id] = dist[u] + dist[y] - dist[anc] * 2;
}
}
st[u] = 2;
}
int main(){
int _;
int n, m;
int a, b, c;
memset(h, -1, sizeof h);
scanf("%d%d", &n, &m);
for(int i = 0;i < n - 1;i ++){
scanf("%d%d%d", &a, &b, &c);
add(a,b,c); add(b,a,c);
}
for(int i = 0;i < m;i ++ ){
scanf("%d%d", &a, &b);
if(a != b){
query[a].push_back({b, i});
query[b].push_back({a, i});
}
}
for(int i = 0;i <= n;i ++) head[i] = i;
dfs(1, -1);
tarjan(1);
for (int i = 0; i < m; i ++ ) printf("%d\n", res[i]);
return 0;
}
当然,这个题还可以用上面提到的 二进制跳 的方法解,用 LCA + 树状前缀和数组 去记录长度,就是复杂度没有上面的好
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define INF 0x3f3f3f3f
using namespace std;
const int N = 10010, M = 2 * N;
int h[N], e[M], ne[M], w[M], idx;
int n, m;
int fa[N][17], p[N];
int sum[N];
int depth[N];
void add(int a, int b, int c){
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx ++;
}
void bfs(){
memset(sum, INF, sizeof sum);
memset(depth, 0x3f, sizeof depth);
int hh = 0, tt = 0;
p[0] = 1;
depth[1] = 1, depth[0] = 0;
sum[1] = 0, sum[0] = 0;
while(hh <= tt){
int t = p[hh ++];
// cout << "t = " << t << endl;
for(int i = h[t]; ~i ;i = ne[i]){
int j = e[i];
if(depth[j] > depth[t] + 1){
sum[j] = sum[t] + w[i];
depth[j] = depth[t] + 1;
p[ ++ tt] = j;
fa[j][0] = t;
for(int k = 1;k <= 16;k ++)
fa[j][k] = fa[fa[j][k-1]][k-1];
}
}
}
return ;
}
int lca(int a, int b){
if(depth[a] > depth[b]) swap(a, b);
for(int i = 16; i >= 0;i --)
if(depth[fa[b][i]] >= depth[a])
b = fa[b][i];
if(a == b) return a;
for(int i = 16;i >= 0;i --)
if(fa[a][i] != fa[b][i]){
a = fa[a][i];
b = fa[b][i];
}
return fa[a][0];
}
int main(){
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
for(int i = 0;i < n-1;i ++){
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
bfs();
//for(int i = 1;i <= n;i ++)
// cout << "i = " << i << " sum[i] = " << sum[i] << endl;
for(int i = 0;i < m;i ++){
int a, b;
scanf("%d%d", &a, &b);
int t = lca(a, b);
printf("%d\n",sum[a] + sum[b] - 2*sum[t]);
}
}
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
const int N = 100010, M = 2 * N;
int n, m;
int h[N], e[M], ne[M], idx;
int depth[N], fa[N][17];
int d[N];
int q[N];
int ans;
void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void bfs(){
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[1] = 1;
int hh = 0, tt = 0;
q[0] = 1;
while(hh <= tt){
int t = q[hh ++];
for(int i = h[t]; ~i ; i = ne[i]){
int j = e[i];
if(depth[j] > depth[t] + 1){
depth[j] = depth[t] + 1;
q[ ++ tt ] = j;
fa[j][0] = t;
for(int k = 1;k <= 16;k ++)
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
int lca(int a, int b){
if(depth[a] < depth[b]) swap(a, b);
for(int k = 16; k >= 0;k --)
if(depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if(a == b) return a;
for(int k = 16;k >= 0;k--)
if(fa[a][k] != fa[b][k]){
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
int dfs(int u, int father){
int res = d[u];
for(int i = h[u]; ~i ;i = ne[i]){
int j = e[i];
if(j != father){
int s = dfs(j, u);
if(s == 0) ans += m;
else if(s == 1) ans ++;
res += s;
}
}
return res;
}
int main(){
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
for(int i = 0;i < n-1;i ++){
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
bfs();
for(int i = 0;i < m;i ++){
int a, b;
scanf("%d%d",&a,&b);
int p = lca(a, b);
d[a] ++, d[b] ++, d[p] -= 2;
}
dfs(1, -1);
printf("%d\n",ans);
return 0;
}