[树上倍增] LCA问题与ST表的使用

先理解LCA模板

    7月伊始,这次决定好好学习一下LCA以及ST表的使用,给暑假开个好头。所以呢,先从解决LCA问题说起。
   
   在有根树中,两个结点 u , v u, v u,v 的公共祖先中距离最近的那个被称为最近公共祖先(LCA, Lowest Common Ancestor)。比如在下图中,根是结点8,LCA(3, 11) = 10, LCA(15, 12) = 4。
在这里插入图片描述
   在有根树中由于根的存在,结点与根结点的最短路径长度为这个结点的深度(depth)。那么,如果 L C A ( u , v ) = w LCA(u,v) = w LCA(u,v)=w, 让 u u u 向上走 d e p t h ( u ) − d e p t h ( w ) depth(u) - depth(w) depth(u)depth(w) 步,让 v v v 向上走 d e p t h ( v ) − d e p t h ( w ) depth(v) - depth(w) depth(v)depth(w) 步,就都可以走到 w w w。因此,首先让 u , v u, v u,v 中较深的一方走 ∣ d e p t h ( u ) − d e p t h ( v ) ∣ |depth(u) - depth(v)| depth(u)depth(v) 步,再一起一步一步向上走,直到走到同一个结点,就可以在线性时间里求出LCA。
   
   而有没有更优的算法呢?ST表可以做到!
   我们可以用父结点的信息推出子节点的信息,如 p a r e n t 2 [ v ] = p a r e n t [ p a r e n t [ v ] ] parent2[v] = parent[parent[v]] parent2[v]=parent[parent[v]], p a r e n t 4 [ v ] = p a r e n t 2 [ p a r e n t 2 [ v ] ] parent4[v] = parent2[parent2[v]] parent4[v]=parent2[parent2[v]],若我们记录 f [ i ] [ k ] f[i][k] f[i][k] i i i 结点往上的第 2 k 2^k 2k 个结点,我们可以得到以下的转移关系:
f [ i ] [ 0 ] = p a r e n t [ i ] ,       f [ i ] [ j ] = f [ i ] [ f [ i ] [ j − 1 ] ] ( 如 果 f [ i ] [ j − 1 ] 存 在 ) f[i][0] = parent[i], \ \ \ \ \ f[i][j] = f[i][f[i][j-1]](如果f[i][j-1]存在) f[i][0]=parent[i],     f[i][j]=f[i][f[i][j1]](f[i][j1])
   所以我们在一开始时,可以预处理得到每个结点的深度与父结点 f [ i ] [ 0 ] f[i][0] f[i][0], 而类似于区间 d p dp dp,我们也可以通过两重循环,其中第一重为区间长度,因为长度长的区间必须通过短区间合并而来,第二重为各个结点

void dfs(int x, int fa, int d)     //通过dfs预处理所有结点的父结点与深度
{
    f[x][0] = fa;              //x的父结点即为f[x][0]
    depth[x] = d;              //得到长度
    for(unsigned int i = 0; i < G[x].size(); i++)
        if(G[x][i] != fa)
            dfs(G[x][i], x, d + 1);
}

dfs(root, 0, 0);
for(int j = 1; j < MAX_LOG_V; j++)    //先循环步数
{
	for(int i = 1; i <= n; i++)   //再循环结点
	if(f[i][j-1])            //如果f[i][j-1]存在
		f[i][j] = f[f[i][j-1]][j-1];
}

   
   而得到了 st表,该如何求解LCA呢?首先第一步还是让 u , v u, v u,v走到同一深度,由于 st 表的第二维是指数级别的增长,因此相比于线性时间的一步步地上升,我们可以用对数时间迅速地上升至同一高度
   在上升到同一高度之后,原来的方法是两个结点同时一步步地上升直到到达同一个结点,而我们现在可以采取二分的方式。比如两个结点距离 LCA 的距离为 15, 若是同时上升 16,则会超过 LCA,于是不上升;同时上升 8,距离变成 7,再上升4,距离变成3,再上升2,距离变成1,最后一起上升1,到达 LCA。所以这里步数是从大到小进行遍历,这个处理技巧是非常关键的:

int lca(int u, int v)
{
    if(depth[v] > depth[u]) swap(u, v);       //让u的深度更大

    for(int i = 0; i <= MAX_LOG_V; i++)                   //让u,v到达统一深度
        if((depth[u] - depth[v]) >> i & 1)
            u = f[u][i];

    if(u == v) return u;
    for(int k = MAX_LOG_V; k >= 0; k--)         //利用二分(st表)来计算LCA
    {
        if(f[u][k] != f[v][k])         //若是不会超出LCA就上升
        {
            u = f[u][k];
            v = f[v][k];
        }
    }
    return f[u][0];
}

   
   这样通过预处理 O ( n l o g n ) O(nlogn) O(nlogn), 每次查询 LCA 的时间就可以变成 O ( l o g n ) O(logn) O(logn) 了, 所以这里面倍增的思想很关键。
   附上POJ的模板题以及AC代码 link

//#include <bits/stdc++.h>
#include <stdio.h>
#include <vector>
#include <string.h>

using namespace std;
typedef long long ll;
const int maxn = 1e4 + 10;
const int INF = 0x3f3f3f3f;
const ll mod = 998244353;

vector<int> G[maxn];
int fa[maxn], depth[maxn], f[maxn][15], n, root;

void dfs(int x, int fa, int d)     //通过dfs预处理所有结点的父结点与深度
{
    f[x][0] = fa;              //x的父结点即为f[x][0]
    depth[x] = d;              //得到长度
    for(unsigned int i = 0; i < G[x].size(); i++)
        if(G[x][i] != fa)
            dfs(G[x][i], x, d + 1);
}

int lca(int u, int v)
{
    if(depth[v] > depth[u]) swap(u, v);       //让u的深度更大

    for(int i = 0; i < 15; i++)                   //让u,v到达统一深度
        if((depth[u] - depth[v]) >> i & 1)
            u = f[u][i];

    if(u == v) return u;
    for(int k = 14; k >= 0; k--)         //利用二分(st表)来计算LCA
    {
        if(f[u][k] != f[v][k])
        {
            u = f[u][k];
            v = f[v][k];
        }
    }
    return f[u][0];
}

int main()
{
    int t;
    scanf("%d", &t);
    while(t--)
    {
        int x, y;
        scanf("%d", &n);
        memset(fa, 0, sizeof(fa));
        memset(f, 0, sizeof(f));
        for(int i = 1; i <= n; i++)
            G[i].clear();

        for(int i = 1; i < n; i++)
        {
            scanf("%d %d", &x, &y);
            G[x].push_back(y);
            G[y].push_back(x);
            fa[y] = x;
        }
        for(int i = 1; i <= n; i++)      //找到root
            if(!fa[i])
            {
                root = i;
                break;
            }

        dfs(root, 0, 0);
        for(int j = 1; j < 15; j++)    //先循环步数
        {
            for(int i = 1; i <= n; i++)   //再循环结点
                if(f[i][j-1])
                    f[i][j] = f[f[i][j-1]][j-1];
        }

        scanf("%d %d", &x, &y);
        printf("%d\n", lca(x, y));
    }
}

   

ST表的用处

   在上面,用到了一种很厉害的数据结构 ST表,其实这也是动态规划的一种,用统计的思想来解决问题。在上面的有根树之中,若要记录每个结点 i i i 往上 k k k 步是哪一个结点,则需要一个二维数组,我们可以进行对比:

普通二维数组ST表
空间复杂度 O ( V 2 ) O(V^2) O(V2) O ( V l o g ( d e p ) ) O(Vlog(dep)) O(Vlog(dep))
时间复杂度 O ( V 2 ) O(V^2) O(V2) O ( V l o g ( d e p ) ) O(Vlog(dep)) O(Vlog(dep))
查询时间 O ( 1 ) O(1) O(1) O ( l o g ( d e p ) ) O(log(dep)) O(log(dep))

   因此 ST表虽然没有记录全部的信息,但是却可以在对数时间里得到我们想要的信息,有点像线段树,树状数组。ST表一个很大的用处是求解区间的最大/小值。
   比如求最大值,对于数组 a [ i ] a[i] a[i], d p [ i ] [ j ] dp[i][j] dp[i][j] 表示从下标 i i i 开始长度为 2 j 2^j 2j 的区间的最大值,可以得出 d p [ i ] [ 0 ] = a [ 0 ] dp[i][0] = a[0] dp[i][0]=a[0], 而 j j j 每增加1,就说明区间长度增倍,则有以下转移式:
d p [ i ] [ j ] = m a x ( d p [ i ] [ j − 1 ] , d p [ i + 2 j − 1 ] [ j − 1 ] dp[i][j] = max(dp[i][j-1], dp[i + 2^{j-1}][j-1] dp[i][j]=max(dp[i][j1],dp[i+2j1][j1]
   这样子就可以在 O ( n l o g n ) O(nlogn) O(nlogn) 的时间里维护出 ST表。那该如何求出某个区间 [ L , R ] [L, R] [L,R] 的最大值呢?比如区间 [ 3 , 8 ] [3, 8] [3,8],长度虽然不是2的幂次,但可以通过分别查找区间 [ 3 , 6 ] [3, 6] [3,6] 和 区间 [ 5 , 8 ] [5 ,8] [5,8],即 d p [ 3 ] [ 2 ] dp[3][2] dp[3][2] d p [ 5 ] [ 2 ] dp[5][2] dp[5][2] 得到。所以两个长度大于等于 ( R − L + 1 ) / 2 (R - L + 1)/2 (RL+1)/2 的区间 (一个左边界为 L L L, 一个右边界为 R R R) 可以覆盖我们要求的区间。

void init_st(int n)
{
    for(int i = 1; i <= n; i++)  //区间跨度为1
        f[i][0] = a[i];
    int m = (int)(log((double)n) / log(2.0));   //算出最大跨度2^m

    for(int j = 1; j <= m; j++)               //第一维为区间长度
        for(int i = 1; i + (1<<j) - 1 <= n; i++)
            f[i][j] = max(f[i][j-1], f[i+(1<<(j-1))][j-1]);
}

int Find(int l, int r)
{
    int k = (int)(log(1.0 * (r-l+1)) / log(2.0));     //求出两个区间的长度
    return max(f[l][k], f[r-(1<<k)+1][k]);
}

   
   下面来看一道 ST表稍微难一点的运用:link,大意是给你一个长为 n n n 的序列和一个常数 k k k m m m 次询问,每次查询一个区间 [ l , r ] [l,r] [l,r] 内所有数最少分成多少个连续段,使得每段的和都 ≤ k ≤ k k,如果这一次查询无解,输出 “Chtholly”,其中 ( n , m ≤ 1 e 6 n, m≤ 1e6 n,m1e6)
   由于询问次数很多,只能把每次询问时间压到对数或者常数级别,而我们可以用 st 表来记录这些信息。若记 d p [ i ] [ 0 ] dp[i][0] dp[i][0] 表示区间从当前开始,下一个区间最晚的位置,即 [ i , d p [ i ] [ 0 ] − 1 ] [i, dp[i][0] - 1] [i,dp[i][0]1] 求和小于等于 k k k, [ i , d p [ i ] [ 0 ] ] [i, dp[i][0]] [i,dp[i][0]] 大于 k k k, 那么相应地 d p [ i ] [ j ] dp[i][j] dp[i][j] 就是下 2 j 2^j 2j 个区间最晚开始的位置,我们可以得到以下转移方程:
d p [ i ] [ j ] = d p [ d p [ i ] [ j − 1 ] ] [ j − 1 ] dp[i][j] = dp[dp[i][j-1]][j-1] dp[i][j]=dp[dp[i][j1]][j1]
   而这里需要考虑到两种情况:
   ①若是有一个数大于 k k k, 那么这个值是无法在区间里的,在处理时就非常麻烦,所以在读入数据时,我们可以标记哪些位置的数值大于 k k k, 然后记录前缀和,这样就可以快速判断一个区间里是不是有大于 k k k 的数了,然后我们 st 表只处理小于等于 k k k 的数。
   ②若是区间可以包含到最后一个数都不超过 k k k,也需要特殊处理,我就让 d p [ i ] [ j ] = 0 dp[i][j] = 0 dp[i][j]=0 表示可以扩展到最后一个元素。
   而求 d p [ i ] [ 0 ] dp[i][0] dp[i][0], 可以用尺取法维护一个区间,其他细节可以参照代码来看:

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int maxn = 1e6 + 3;
const int INF = 0x3f3f3f3f;
const ll mod = 998244353;

int a[maxn], b[maxn], f[maxn][21], n, m, k, tol;

int main()
{
    scanf("%d %d %d", &n, &m, &k);
    for(int i = 1; i <= n; i++)
    {
        scanf("%d", &a[++tol]); 
        if(a[tol] > k)       //只保留小于等于k的数
        {
            b[i] = 1;     //超过k的下标用b数组标记
            tol--;
        }
    }
    for(int i = 2; i <= n; i++)    //求前缀和,现在b[i]表示[1,i]超过k的数
        b[i] += b[i-1];

    int l = 1, r = 1;             
    ll tmp = 0;
    while(l <= tol)
    {
        while(r <= n && tmp + a[r] <= k)     //[l, r)区间和小于等于k,找到r的最大值
        {
            tmp += a[r];
            r++;
        }

        if(r == n + 1)  f[l][0] = 0;         //如果剩下的所有元素都被包含了,用0表示
        else f[l][0] = r;
        tmp -= a[l];
        l++;
    }

    for(int j = 1; j <= 20; j++)
        for(int i = 1; i + (1<<j) <= n; i++)
        {
            if(f[i][j-1] == 0 || f[f[i][j-1]][j-1] == 0)  //有这个条件可以直接break
                break;
            f[i][j] = f[f[i][j-1]][j-1];
        }

    for(int cas = 1; cas <= m; cas++)
    {
        int ans = 0;
        scanf("%d %d", &l, &r);
        if(b[r] - b[l-1] > 0)       //如果区间包含有大于k的元素
        {
            printf("Chtholly\n");
            continue;
        }
        r -= b[r];  l -= b[l-1];          //得到现在的下标
        for(int j = 20; j >= 0; j--)
        {
            if(f[l][j] <= r && f[l][j] > 0)
            {
                ans += 1<<j;
                l = f[l][j];
            }
        }
        printf("%d\n", ans + 1);

    }

}

   

ST表比较灵活的使用

   还是从题目开始入手吧,其实我一开始是做到了接下来的这道题,然后才了解了 ST表和树上倍增…所以是通过题来花时间了解数据结构,然后最后一步当然是看看这题是如何运用的啦。
   link有一个树状的城市网络(即 n n n 个城市由 n − 1 n-1 n1 条道路连接的连通图),首都为 1 号城市,每个城市售卖价值为 a i a_i ai 的珠宝。你是一个珠宝商,现在安排有 q q q 次行程,每次行程为从 u u u 号城市前往 v v v 号城市(走最短路径),保证 v v v u u u 前往首都的最短路径上。 在每次行程开始时,你手上有价值为 c c c 的珠宝(每次行程可能不同),并且每经过一个城市时(包括 u u u v v v ),假如那个城市中售卖的珠宝比你现在手上的每一种珠宝都要优秀(价值更高,即严格大于),那么你就会选择购入。现在你想要对每一次行程,求出会进行多少次购买事件。其中, 2 ≤ n ≤ 1 0 5 , 1 ≤ q ≤ 1 0 5 2 ≤ n ≤ 10^5 , 1 ≤ q ≤ 10^5 2n105,1q105
   理解一下题意,每个结点都有一个权重,然后每次出发时都给你一个值,并保证你是朝根方向往上走,若是到达一个结点,它的权重比所有经过结点权重以及手上的初值要大,就计数 +1。也就是说,如果你在某个结点处买了珠宝,那么这个结点的权重就是你拥有的最大值,那么从哪里走到这个结点不会影响后续的情况,所以我们不妨记录一下这个信息。
    让 d p [ i ] [ j ] dp[i][j] dp[i][j] 表示若在结点 i i i 买了珠宝,往根向上走,再买到第 2 j 2^j 2j 个珠宝的结点,我们可以得到:
d p [ i ] [ j ] = d p [ d p [ i ] [ j − 1 ] ] [ j − 1 ] dp[i][j] = dp[dp[i][j-1]][j-1] dp[i][j]=dp[dp[i][j1]][j1]
   这里同时有些特殊情况要处理:
   ① d p [ i ] [ j ] dp[i][j] dp[i][j] 可能并不存在,比如对于根结点 1 1 1, d p [ 1 ] [ j ] dp[1][j] dp[1][j] 就是不存在的,不妨我们记 d p [ i ] [ j ] = 0 dp[i][j] = 0 dp[i][j]=0,表示在 i i i 结点出发,买到第 2 j 2^j 2j 件珠宝之前,就已经到达了根结点;
   ② d p [ i ] [ 0 ] dp[i][0] dp[i][0] 是表示下一个买珠宝的城市,这个在动态规划前需要先求出来。如果它的父节点 f a fa fa 权重大于 i i i, 那么 d p [ i ] [ 0 ] = f a dp[i][0] = fa dp[i][0]=fa, 若小于它,该怎么求呢?比如举例,父结点之后买的 6 6 6 件珠宝权重都小于等于结点 i i i, 直到第 7 7 7 个结点,我们可以从大步长开始跨越,那么我们分别跨越 4 , 2 , 1 4, 2, 1 421 个步长到达这个结点(要是难以理解,可以看具体代码的实现)
   
   而现在题目要求的是带着珠宝 c c c 从某个结点 u u u 出发,到达 u u u, 而可以让 u u u 多出一个儿子 u ′ u' u u ′ u' u 权重为 c c c,这样我们就是要求 d p [ u ′ ] [ v ] dp[u'][v] dp[u][v], 求的时候我们也是从大步长开始跨越,可以通过在树中的深度来判断有没有超过目标结点
   为什么要从大步长开始呢?这其实是 st表中一个关键的处理,比如步长相差6,从小步长就会先跨越 1 1 1, 再跨越 2 2 2, 但是 4 4 4 就不行了,这样会凑不齐,而我们从大步长跨越,就会先跨越 4 4 4, 再跨越 2 2 2,这样可以凑出 6 6 6 出来。
   
   下面就是具体的代码实现咯,总结一下, st表关键之处是给出适合题目的定义,题目的解可以用 d p [ i ] [ j ] dp[i][j] dp[i][j] 用对数时间求出来,然后预处理时要求出 d p [ i ] [ 0 ] dp[i][0] dp[i][0]要想清楚什么地方步长要从大到小开始遍历

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int maxn = 2e5 + 10;
const int INF = 0x3f3f3f3f;
const ll mod = 998244353;

int a[maxn], depth[maxn], to[maxn], f[maxn][16], n, q;
vector<int> G[maxn];

void dfs(int x, int fa, int step)
{
    int cur = fa;                        //先预处理出dp[x][0]
    for(int j = 15; j >= 0; j--)
        if(f[cur][j] && a[f[cur][j]] <= a[x])       
            cur = f[cur][j];
    if(a[cur] > a[x])          //这种情况只会在父节点大于子节点情况发生
        f[x][0] = cur;
    else
        f[x][0] = f[cur][0];             

    for(int j = 1; j <= 15; j++)               //进行动态规划
        f[x][j] = f[f[x][j-1]][j-1];

    depth[x] = step;              
    for(auto y: G[x])
        if(y != fa)
            dfs(y, x, step + 1);
}

int main()
{
    scanf("%d %d", &n, &q);  
    for(int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
    int x, y;
    for(int i = 1; i < n; i++)
    {
        scanf("%d %d", &x, &y);
        G[x].push_back(y);
        G[y].push_back(x);
    }
    for(int i = n + 1; i <= n + q; i++)          //给每一个出发结点多一个儿子
    {
        scanf("%d %d %d", &x, &to[i-n], &a[i]);
        G[i].push_back(x);
        G[x].push_back(i);
    }

    dfs(1, 0, 0);

    for(int cas = 1; cas <= q; cas++)           
    {
        int ans = 0, x = cas + n, y = to[cas];
        for(int i = 15; i >= 0; i--)
            if(f[x][i] > 0 && depth[f[x][i]] >= depth[y])     //用深度来判断有没有超过目标结点
            {
                ans += 1<<i;
                x = f[x][i];
            }
        printf("%d\n", ans);
    }
}

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值