先理解LCA模板
7月伊始,这次决定好好学习一下LCA以及ST表的使用,给暑假开个好头。所以呢,先从解决LCA问题说起。
在有根树中,两个结点
u
,
v
u, v
u,v 的公共祖先中距离最近的那个被称为最近公共祖先(LCA, Lowest Common Ancestor)。比如在下图中,根是结点8,LCA(3, 11) = 10, LCA(15, 12) = 4。
在有根树中由于根的存在,结点与根结点的最短路径长度为这个结点的深度(depth)。那么,如果
L
C
A
(
u
,
v
)
=
w
LCA(u,v) = w
LCA(u,v)=w, 让
u
u
u 向上走
d
e
p
t
h
(
u
)
−
d
e
p
t
h
(
w
)
depth(u) - depth(w)
depth(u)−depth(w) 步,让
v
v
v 向上走
d
e
p
t
h
(
v
)
−
d
e
p
t
h
(
w
)
depth(v) - depth(w)
depth(v)−depth(w) 步,就都可以走到
w
w
w。因此,首先让
u
,
v
u, v
u,v 中较深的一方走
∣
d
e
p
t
h
(
u
)
−
d
e
p
t
h
(
v
)
∣
|depth(u) - depth(v)|
∣depth(u)−depth(v)∣ 步,再一起一步一步向上走,直到走到同一个结点,就可以在线性时间里求出LCA。
而有没有更优的算法呢?ST表可以做到!
我们可以用父结点的信息推出子节点的信息,如
p
a
r
e
n
t
2
[
v
]
=
p
a
r
e
n
t
[
p
a
r
e
n
t
[
v
]
]
parent2[v] = parent[parent[v]]
parent2[v]=parent[parent[v]],
p
a
r
e
n
t
4
[
v
]
=
p
a
r
e
n
t
2
[
p
a
r
e
n
t
2
[
v
]
]
parent4[v] = parent2[parent2[v]]
parent4[v]=parent2[parent2[v]],若我们记录
f
[
i
]
[
k
]
f[i][k]
f[i][k] 为
i
i
i 结点往上的第
2
k
2^k
2k 个结点,我们可以得到以下的转移关系:
f
[
i
]
[
0
]
=
p
a
r
e
n
t
[
i
]
,
f
[
i
]
[
j
]
=
f
[
i
]
[
f
[
i
]
[
j
−
1
]
]
(
如
果
f
[
i
]
[
j
−
1
]
存
在
)
f[i][0] = parent[i], \ \ \ \ \ f[i][j] = f[i][f[i][j-1]](如果f[i][j-1]存在)
f[i][0]=parent[i], f[i][j]=f[i][f[i][j−1]](如果f[i][j−1]存在)
所以我们在一开始时,可以预处理得到每个结点的深度与父结点即
f
[
i
]
[
0
]
f[i][0]
f[i][0], 而类似于区间
d
p
dp
dp,我们也可以通过两重循环,其中第一重为区间长度,因为长度长的区间必须通过短区间合并而来,第二重为各个结点:
void dfs(int x, int fa, int d) //通过dfs预处理所有结点的父结点与深度
{
f[x][0] = fa; //x的父结点即为f[x][0]
depth[x] = d; //得到长度
for(unsigned int i = 0; i < G[x].size(); i++)
if(G[x][i] != fa)
dfs(G[x][i], x, d + 1);
}
dfs(root, 0, 0);
for(int j = 1; j < MAX_LOG_V; j++) //先循环步数
{
for(int i = 1; i <= n; i++) //再循环结点
if(f[i][j-1]) //如果f[i][j-1]存在
f[i][j] = f[f[i][j-1]][j-1];
}
而得到了 st表,该如何求解LCA呢?首先第一步还是让
u
,
v
u, v
u,v走到同一深度,由于 st 表的第二维是指数级别的增长,因此相比于线性时间的一步步地上升,我们可以用对数时间迅速地上升至同一高度。
在上升到同一高度之后,原来的方法是两个结点同时一步步地上升直到到达同一个结点,而我们现在可以采取二分的方式。比如两个结点距离 LCA 的距离为 15, 若是同时上升 16,则会超过 LCA,于是不上升;同时上升 8,距离变成 7,再上升4,距离变成3,再上升2,距离变成1,最后一起上升1,到达 LCA。所以这里步数是从大到小进行遍历,这个处理技巧是非常关键的:
int lca(int u, int v)
{
if(depth[v] > depth[u]) swap(u, v); //让u的深度更大
for(int i = 0; i <= MAX_LOG_V; i++) //让u,v到达统一深度
if((depth[u] - depth[v]) >> i & 1)
u = f[u][i];
if(u == v) return u;
for(int k = MAX_LOG_V; k >= 0; k--) //利用二分(st表)来计算LCA
{
if(f[u][k] != f[v][k]) //若是不会超出LCA就上升
{
u = f[u][k];
v = f[v][k];
}
}
return f[u][0];
}
这样通过预处理
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn), 每次查询 LCA 的时间就可以变成
O
(
l
o
g
n
)
O(logn)
O(logn) 了, 所以这里面倍增的思想很关键。
附上POJ的模板题以及AC代码 link
//#include <bits/stdc++.h>
#include <stdio.h>
#include <vector>
#include <string.h>
using namespace std;
typedef long long ll;
const int maxn = 1e4 + 10;
const int INF = 0x3f3f3f3f;
const ll mod = 998244353;
vector<int> G[maxn];
int fa[maxn], depth[maxn], f[maxn][15], n, root;
void dfs(int x, int fa, int d) //通过dfs预处理所有结点的父结点与深度
{
f[x][0] = fa; //x的父结点即为f[x][0]
depth[x] = d; //得到长度
for(unsigned int i = 0; i < G[x].size(); i++)
if(G[x][i] != fa)
dfs(G[x][i], x, d + 1);
}
int lca(int u, int v)
{
if(depth[v] > depth[u]) swap(u, v); //让u的深度更大
for(int i = 0; i < 15; i++) //让u,v到达统一深度
if((depth[u] - depth[v]) >> i & 1)
u = f[u][i];
if(u == v) return u;
for(int k = 14; k >= 0; k--) //利用二分(st表)来计算LCA
{
if(f[u][k] != f[v][k])
{
u = f[u][k];
v = f[v][k];
}
}
return f[u][0];
}
int main()
{
int t;
scanf("%d", &t);
while(t--)
{
int x, y;
scanf("%d", &n);
memset(fa, 0, sizeof(fa));
memset(f, 0, sizeof(f));
for(int i = 1; i <= n; i++)
G[i].clear();
for(int i = 1; i < n; i++)
{
scanf("%d %d", &x, &y);
G[x].push_back(y);
G[y].push_back(x);
fa[y] = x;
}
for(int i = 1; i <= n; i++) //找到root
if(!fa[i])
{
root = i;
break;
}
dfs(root, 0, 0);
for(int j = 1; j < 15; j++) //先循环步数
{
for(int i = 1; i <= n; i++) //再循环结点
if(f[i][j-1])
f[i][j] = f[f[i][j-1]][j-1];
}
scanf("%d %d", &x, &y);
printf("%d\n", lca(x, y));
}
}
ST表的用处
在上面,用到了一种很厉害的数据结构 ST表,其实这也是动态规划的一种,用统计的思想来解决问题。在上面的有根树之中,若要记录每个结点 i i i 往上 k k k 步是哪一个结点,则需要一个二维数组,我们可以进行对比:
普通二维数组 | ST表 | |
---|---|---|
空间复杂度 | O ( V 2 ) O(V^2) O(V2) | O ( V l o g ( d e p ) ) O(Vlog(dep)) O(Vlog(dep)) |
时间复杂度 | O ( V 2 ) O(V^2) O(V2) | O ( V l o g ( d e p ) ) O(Vlog(dep)) O(Vlog(dep)) |
查询时间 | O ( 1 ) O(1) O(1) | O ( l o g ( d e p ) ) O(log(dep)) O(log(dep)) |
因此 ST表虽然没有记录全部的信息,但是却可以在对数时间里得到我们想要的信息,有点像线段树,树状数组。ST表一个很大的用处是求解区间的最大/小值。
比如求最大值,对于数组
a
[
i
]
a[i]
a[i],
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 表示从下标
i
i
i 开始长度为
2
j
2^j
2j 的区间的最大值,可以得出
d
p
[
i
]
[
0
]
=
a
[
0
]
dp[i][0] = a[0]
dp[i][0]=a[0], 而
j
j
j 每增加1,就说明区间长度增倍,则有以下转移式:
d
p
[
i
]
[
j
]
=
m
a
x
(
d
p
[
i
]
[
j
−
1
]
,
d
p
[
i
+
2
j
−
1
]
[
j
−
1
]
dp[i][j] = max(dp[i][j-1], dp[i + 2^{j-1}][j-1]
dp[i][j]=max(dp[i][j−1],dp[i+2j−1][j−1]
这样子就可以在
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn) 的时间里维护出 ST表。那该如何求出某个区间
[
L
,
R
]
[L, R]
[L,R] 的最大值呢?比如区间
[
3
,
8
]
[3, 8]
[3,8],长度虽然不是2的幂次,但可以通过分别查找区间
[
3
,
6
]
[3, 6]
[3,6] 和 区间
[
5
,
8
]
[5 ,8]
[5,8],即
d
p
[
3
]
[
2
]
dp[3][2]
dp[3][2] 与
d
p
[
5
]
[
2
]
dp[5][2]
dp[5][2] 得到。所以两个长度大于等于
(
R
−
L
+
1
)
/
2
(R - L + 1)/2
(R−L+1)/2 的区间 (一个左边界为
L
L
L, 一个右边界为
R
R
R) 可以覆盖我们要求的区间。
void init_st(int n)
{
for(int i = 1; i <= n; i++) //区间跨度为1
f[i][0] = a[i];
int m = (int)(log((double)n) / log(2.0)); //算出最大跨度2^m
for(int j = 1; j <= m; j++) //第一维为区间长度
for(int i = 1; i + (1<<j) - 1 <= n; i++)
f[i][j] = max(f[i][j-1], f[i+(1<<(j-1))][j-1]);
}
int Find(int l, int r)
{
int k = (int)(log(1.0 * (r-l+1)) / log(2.0)); //求出两个区间的长度
return max(f[l][k], f[r-(1<<k)+1][k]);
}
下面来看一道 ST表稍微难一点的运用:link,大意是给你一个长为
n
n
n 的序列和一个常数
k
k
k 有
m
m
m 次询问,每次查询一个区间
[
l
,
r
]
[l,r]
[l,r] 内所有数最少分成多少个连续段,使得每段的和都
≤
k
≤ k
≤k,如果这一次查询无解,输出 “Chtholly”,其中 (
n
,
m
≤
1
e
6
n, m≤ 1e6
n,m≤1e6)
由于询问次数很多,只能把每次询问时间压到对数或者常数级别,而我们可以用 st 表来记录这些信息。若记
d
p
[
i
]
[
0
]
dp[i][0]
dp[i][0] 表示区间从当前开始,下一个区间最晚的位置,即
[
i
,
d
p
[
i
]
[
0
]
−
1
]
[i, dp[i][0] - 1]
[i,dp[i][0]−1] 求和小于等于
k
k
k,
[
i
,
d
p
[
i
]
[
0
]
]
[i, dp[i][0]]
[i,dp[i][0]] 大于
k
k
k, 那么相应地
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 就是下
2
j
2^j
2j 个区间最晚开始的位置,我们可以得到以下转移方程:
d
p
[
i
]
[
j
]
=
d
p
[
d
p
[
i
]
[
j
−
1
]
]
[
j
−
1
]
dp[i][j] = dp[dp[i][j-1]][j-1]
dp[i][j]=dp[dp[i][j−1]][j−1]
而这里需要考虑到两种情况:
①若是有一个数大于
k
k
k, 那么这个值是无法在区间里的,在处理时就非常麻烦,所以在读入数据时,我们可以标记哪些位置的数值大于
k
k
k, 然后记录前缀和,这样就可以快速判断一个区间里是不是有大于
k
k
k 的数了,然后我们 st 表只处理小于等于
k
k
k 的数。
②若是区间可以包含到最后一个数都不超过
k
k
k,也需要特殊处理,我就让
d
p
[
i
]
[
j
]
=
0
dp[i][j] = 0
dp[i][j]=0 表示可以扩展到最后一个元素。
而求
d
p
[
i
]
[
0
]
dp[i][0]
dp[i][0], 可以用尺取法维护一个区间,其他细节可以参照代码来看:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e6 + 3;
const int INF = 0x3f3f3f3f;
const ll mod = 998244353;
int a[maxn], b[maxn], f[maxn][21], n, m, k, tol;
int main()
{
scanf("%d %d %d", &n, &m, &k);
for(int i = 1; i <= n; i++)
{
scanf("%d", &a[++tol]);
if(a[tol] > k) //只保留小于等于k的数
{
b[i] = 1; //超过k的下标用b数组标记
tol--;
}
}
for(int i = 2; i <= n; i++) //求前缀和,现在b[i]表示[1,i]超过k的数
b[i] += b[i-1];
int l = 1, r = 1;
ll tmp = 0;
while(l <= tol)
{
while(r <= n && tmp + a[r] <= k) //[l, r)区间和小于等于k,找到r的最大值
{
tmp += a[r];
r++;
}
if(r == n + 1) f[l][0] = 0; //如果剩下的所有元素都被包含了,用0表示
else f[l][0] = r;
tmp -= a[l];
l++;
}
for(int j = 1; j <= 20; j++)
for(int i = 1; i + (1<<j) <= n; i++)
{
if(f[i][j-1] == 0 || f[f[i][j-1]][j-1] == 0) //有这个条件可以直接break
break;
f[i][j] = f[f[i][j-1]][j-1];
}
for(int cas = 1; cas <= m; cas++)
{
int ans = 0;
scanf("%d %d", &l, &r);
if(b[r] - b[l-1] > 0) //如果区间包含有大于k的元素
{
printf("Chtholly\n");
continue;
}
r -= b[r]; l -= b[l-1]; //得到现在的下标
for(int j = 20; j >= 0; j--)
{
if(f[l][j] <= r && f[l][j] > 0)
{
ans += 1<<j;
l = f[l][j];
}
}
printf("%d\n", ans + 1);
}
}
ST表比较灵活的使用
还是从题目开始入手吧,其实我一开始是做到了接下来的这道题,然后才了解了 ST表和树上倍增…所以是通过题来花时间了解数据结构,然后最后一步当然是看看这题是如何运用的啦。
link有一个树状的城市网络(即
n
n
n 个城市由
n
−
1
n-1
n−1 条道路连接的连通图),首都为 1 号城市,每个城市售卖价值为
a
i
a_i
ai 的珠宝。你是一个珠宝商,现在安排有
q
q
q 次行程,每次行程为从
u
u
u 号城市前往
v
v
v 号城市(走最短路径),保证
v
v
v 在
u
u
u 前往首都的最短路径上。 在每次行程开始时,你手上有价值为
c
c
c 的珠宝(每次行程可能不同),并且每经过一个城市时(包括
u
u
u 和
v
v
v ),假如那个城市中售卖的珠宝比你现在手上的每一种珠宝都要优秀(价值更高,即严格大于),那么你就会选择购入。现在你想要对每一次行程,求出会进行多少次购买事件。其中,
2
≤
n
≤
1
0
5
,
1
≤
q
≤
1
0
5
2 ≤ n ≤ 10^5 , 1 ≤ q ≤ 10^5
2≤n≤105,1≤q≤105。
理解一下题意,每个结点都有一个权重,然后每次出发时都给你一个值,并保证你是朝根方向往上走,若是到达一个结点,它的权重比所有经过结点权重以及手上的初值要大,就计数 +1。也就是说,如果你在某个结点处买了珠宝,那么这个结点的权重就是你拥有的最大值,那么从哪里走到这个结点不会影响后续的情况,所以我们不妨记录一下这个信息。
让
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 表示若在结点
i
i
i 买了珠宝,往根向上走,再买到第
2
j
2^j
2j 个珠宝的结点,我们可以得到:
d
p
[
i
]
[
j
]
=
d
p
[
d
p
[
i
]
[
j
−
1
]
]
[
j
−
1
]
dp[i][j] = dp[dp[i][j-1]][j-1]
dp[i][j]=dp[dp[i][j−1]][j−1]
这里同时有些特殊情况要处理:
①
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 可能并不存在,比如对于根结点
1
1
1,
d
p
[
1
]
[
j
]
dp[1][j]
dp[1][j] 就是不存在的,不妨我们记
d
p
[
i
]
[
j
]
=
0
dp[i][j] = 0
dp[i][j]=0,表示在
i
i
i 结点出发,买到第
2
j
2^j
2j 件珠宝之前,就已经到达了根结点;
②
d
p
[
i
]
[
0
]
dp[i][0]
dp[i][0] 是表示下一个买珠宝的城市,这个在动态规划前需要先求出来。如果它的父节点
f
a
fa
fa 权重大于
i
i
i, 那么
d
p
[
i
]
[
0
]
=
f
a
dp[i][0] = fa
dp[i][0]=fa, 若小于它,该怎么求呢?比如举例,父结点之后买的
6
6
6 件珠宝权重都小于等于结点
i
i
i, 直到第
7
7
7 个结点,我们可以从大步长开始跨越,那么我们分别跨越
4
,
2
,
1
4, 2, 1
4,2,1 个步长到达这个结点(要是难以理解,可以看具体代码的实现)
而现在题目要求的是带着珠宝
c
c
c 从某个结点
u
u
u 出发,到达
u
u
u, 而可以让
u
u
u 多出一个儿子
u
′
u'
u′,
u
′
u'
u′ 权重为
c
c
c,这样我们就是要求
d
p
[
u
′
]
[
v
]
dp[u'][v]
dp[u′][v], 求的时候我们也是从大步长开始跨越,可以通过在树中的深度来判断有没有超过目标结点。
为什么要从大步长开始呢?这其实是 st表中一个关键的处理,比如步长相差6,从小步长就会先跨越
1
1
1, 再跨越
2
2
2, 但是
4
4
4 就不行了,这样会凑不齐,而我们从大步长跨越,就会先跨越
4
4
4, 再跨越
2
2
2,这样可以凑出
6
6
6 出来。
下面就是具体的代码实现咯,总结一下, st表关键之处是给出适合题目的定义,题目的解可以用
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 用对数时间求出来,然后预处理时要求出
d
p
[
i
]
[
0
]
dp[i][0]
dp[i][0], 要想清楚什么地方步长要从大到小开始遍历。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 2e5 + 10;
const int INF = 0x3f3f3f3f;
const ll mod = 998244353;
int a[maxn], depth[maxn], to[maxn], f[maxn][16], n, q;
vector<int> G[maxn];
void dfs(int x, int fa, int step)
{
int cur = fa; //先预处理出dp[x][0]
for(int j = 15; j >= 0; j--)
if(f[cur][j] && a[f[cur][j]] <= a[x])
cur = f[cur][j];
if(a[cur] > a[x]) //这种情况只会在父节点大于子节点情况发生
f[x][0] = cur;
else
f[x][0] = f[cur][0];
for(int j = 1; j <= 15; j++) //进行动态规划
f[x][j] = f[f[x][j-1]][j-1];
depth[x] = step;
for(auto y: G[x])
if(y != fa)
dfs(y, x, step + 1);
}
int main()
{
scanf("%d %d", &n, &q);
for(int i = 1; i <= n; i++)
scanf("%d", &a[i]);
int x, y;
for(int i = 1; i < n; i++)
{
scanf("%d %d", &x, &y);
G[x].push_back(y);
G[y].push_back(x);
}
for(int i = n + 1; i <= n + q; i++) //给每一个出发结点多一个儿子
{
scanf("%d %d %d", &x, &to[i-n], &a[i]);
G[i].push_back(x);
G[x].push_back(i);
}
dfs(1, 0, 0);
for(int cas = 1; cas <= q; cas++)
{
int ans = 0, x = cas + n, y = to[cas];
for(int i = 15; i >= 0; i--)
if(f[x][i] > 0 && depth[f[x][i]] >= depth[y]) //用深度来判断有没有超过目标结点
{
ans += 1<<i;
x = f[x][i];
}
printf("%d\n", ans);
}
}