第二篇博客我们今天来讲讲LCA,其实我也是刚刚学会的。效果还不错。
LCA一般有三种算法:倍增、RMQ、tarjan。
我们先来聊聊倍增,建图O(nlogn),查询O(logn),代码简单易懂,实际效率也不错。
我们就以code[vs]2370 小机房的树为例吧!
2370 小机房的树
小机房有棵焕狗种的树,树上有N个节点,节点标号为0到N-1,有两只虫子名叫飘狗和大吉狗,分居在两个不同的节点上。有一天,他们想爬到一个节点上去搞基,但是作为两只虫子,他们不想花费太多精力。已知从某个节点爬到其父亲节点要花费 c 的能量(从父亲节点爬到此节点也相同),他们想找出一条花费精力最短的路,以使得搞基的时候精力旺盛,他们找到你要你设计一个程序来找到这条路,要求你告诉他们最少需要花费多少精力
第一行一个n,接下来n-1行每一行有三个整数u,v, c 。表示节点 u 爬到节点 v 需要花费 c 的精力。
第n+1行有一个整数m表示有m次询问。接下来m行每一行有两个整数 u ,v 表示两只虫子所在的节点
一共有m行,每一行一个整数,表示对于该次询问所得出的最短距离。
3
1 0 1
2 0 1
3
1 0
2 0
1 2
1
1
2
1<=n<=50000, 1<=m<=75000, 0<=c<=1000
这就是典型的LCA裸题。我们可以这样想,fi,j表示i往上走2^j层的点,gi,j表示i往上走2^j层的费用,不难想到递推方程:f[i][j]=f[f[i][j-1]][j-1],gi,j就请自己YY一下了。这可以和树的深度一起在dfs中预处理。
至于查询,我们就可下用倍增(位运算)将a,b跳到相同的深度(跳a,b中深度深的),之后一起跳(只要f[a][i]!=f[b][i]就可以跳)。
代码如下:
1 #include<cstdio> 2 #include<cstdlib> 3 #include<cstring> 4 using namespace std; 5 #define maxn 50010 6 int n,m,cnt; 7 int deep[maxn],f[maxn][32],g[maxn][32];//deep表一个点的深度 8 int side[maxn],next[maxn*2],toit[maxn*2],cost[maxn*2];//边集数组 9 bool vis[maxn];//防止点被走重 10 inline void add(int a,int b,int c) 11 { 12 cnt++; 13 next[cnt] = side[a]; 14 side[a] = cnt; 15 toit[cnt] = b; 16 cost[cnt] = c; 17 } 18 19 inline void dfs(int a,int dep) 20 { 21 vis[a] = true; 22 int i,t; 23 deep[a] = dep; 24 i = 1; 25 while ((1<<i) <= dep)//递推 26 { 27 f[a][i] = f[f[a][i-1]][i-1]; 28 g[a][i] = g[f[a][i-1]][i-1]+g[a][i-1]; 29 i++; 30 } 31 for (i = side[a];i;i=next[i]) 32 { 33 t = toit[i]; 34 if (!vis[t]) 35 f[t][0] = a,g[t][0] = cost[i],dfs(t,dep+1); 36 } 37 } 38 39 inline int jump(int &a,int step) 40 { 41 int i = 0,ret = 0; 42 while (step > 0) 43 { 44 if (step & 1) 45 ret +=g[a][i],a = f[a][i]; 46 i++; 47 step >>= 1; 48 } 49 return ret; 50 } 51 52 inline int work(int a,int b) 53 { 54 int ret = 0; 55 if (deep[a] > deep[b]) 56 ret += jump(a,deep[a]-deep[b]); 57 else ret += jump(b,deep[b]-deep[a]); 58 int i = 0; 59 while (i >= 0) 60 { 61 if (f[a][i] != f[b][i]) 62 { 63 ret += g[a][i]; 64 ret += g[b][i]; 65 a = f[a][i]; 66 b = f[b][i]; 67 i++; 68 } 69 else 70 i--; 71 } 72 if (a != b) 73 ret = ret + g[a][0] + g[b][0]; 74 return ret; 75 } 76 77 int main() 78 { 79 scanf("%d",&n); 80 int i,a,b,c; 81 for (i = 1;i<n;i++) 82 scanf("%d %d %d",&a,&b,&c),add(a,b,c),add(b,a,c); 83 dfs(0,0); 84 scanf("%d",&m); 85 while (m) 86 { 87 m--; 88 scanf("%d %d",&a,&b); 89 printf("%d\n",work(a,b)); 90 } 91 return 0; 92 }
然后,我们来看看RMQ的算法。 RMQ 中最经典的是 ST 算法。
ST 算法,基于如下原理:
令 f[i][j] 表示 从 i 开始,连续 2^j 个最大值的大小。
我们不难得出如下递推公式: f[i][j] = min( f[i][j-1] , f[i + (1 << (j - 1))][j-1] );
这样 f 数组可以在 NlogN 时间内算出。那么我们可以在 O(1) 的时间内回答每个询问,但是我们需要调用 cmath 库中的 Log2(x) 函数 。。。。。 导致这个玩意儿甚至比上面的 Log(n) 算法的常数更大一点。
好吧,其实我们跑题了。。。。。
现在来看一看为什么 LCA 问题可以转变成 RMQ 问题。
我们先对一个树进行 DFS 遍历,记录每个点入栈和退栈的次序,我们称此序列为欧拉序列。举个例子来说就是这样子的!
(1)
/ \
(2) (7)
/ \ \
(3) (4) (8)
/ \
(5) (6)
一个nlogn 预处理,O(1)查询的算法.
Step 1:
按先序遍历整棵树,记下两个信息:结点访问顺序和结点深度.
如上图:
结点访问顺序是: 1 2 3 2 4 5 4 6 4 2 1 7 8 7 1 //共2n-1个值
结点对应深度是: 0 1 2 1 2 3 2 3 2 1 0 1 2 1 0
Step 2:
如果查询结点3与结点6的公共祖先,则考虑在访问顺序中
3第一次出现,到6第一次出现的子序列: 3 2 4 5 4 6.
这显然是由结点3到结点6的一条路径.
在这条路径中,深度最小的就是最近公共祖先(LCA). 即
结点2是3和6的LCA.
Step 3:
于是问题转化为, 给定一个数组R,及两个数字i,j,如何找出
数组R中从i位置到j位置的最小值..
如上例,就是R[]={0,1,2,1,2,3,2,3,2,1,0,1,2,1,0}.
i=2;j=7;
接下来就是经典的RMQ问题.
我们看一看这种方法的代码:
1 #include<cstdlib> 2 #include<cstdio> 3 #include<algorithm> 4 #include<cstring> 5 #include<cmath> 6 #define maxn 111000 7 using namespace std; 8 int n,m,cnt; 9 int last[maxn]; 10 struct E{ 11 int to,cost,next; 12 }e[maxn]; 13 int pre[maxn],pos[maxn],dep[maxn]; 14 int rmq[maxn][32]; 15 inline void add(int u,int v,int c) 16 { 17 e[++cnt] = (E){v,c,last[u]}; 18 last[u] = cnt; 19 e[++cnt] = (E){u,c,last[v]}; 20 last[v] = cnt; 21 } 22 void dfs(int x,int fa,int deep) 23 { 24 pre[++cnt] = x; 25 pos[x] = cnt; 26 dep[cnt] = deep; 27 for(int i=last[x];i;i = e[i].next){ 28 int to = e[i].to; 29 if(to == fa) continue; 30 dfs(to,x,deep+e[i].cost); 31 pre[++cnt] = x; 32 dep[cnt] = deep; 33 } 34 } 35 void st() 36 { 37 memset(rmq,0x3f,sizeof(rmq)); 38 for(int i=1;i<=cnt;++i) 39 rmq[i][0] = dep[i]; 40 for(int j=1;(1<<j)<=cnt;++j) 41 for(int i=1;i + (1<<j) - 1<=cnt;++i) 42 rmq[i][j] = min(rmq[i][j-1], rmq[i+(1<<(j-1))][j-1]); 43 } 44 int query(int l,int r) 45 { 46 int k = log2(r-l+1); 47 return min(rmq[l][k],rmq[r-(1<<k)+1][k]); 48 } 49 int main() 50 { 51 scanf("%d",&n); 52 for(int i=1;i<n;++i){ 53 int u,v,c; 54 scanf("%d %d %d",&u,&v,&c); 55 u++;v++; 56 add(u,v,c); 57 } 58 cnt = 0; 59 dfs(1,0,0); 60 st(); 61 scanf("%d",&m); 62 for(int i=1;i<=m;++i){ 63 int u,v; 64 scanf("%d %d",&u,&v); 65 u++;v++; 66 if(pos[u] > pos[v]) 67 swap(u,v); 68 int res = query(pos[u],pos[v]); 69 printf("%d\n",dep[pos[u]] + dep[pos[v]] - res * 2); 70 } 71 return 0; 72 }