算法描述
根据上一个博客介绍的dfs序以及欧拉序能够把树上的点转为线性的区间点,从而可以用区间的数据结构去维护。根据欧拉序的定义,我们会发现树上任意两点的第一次出现位置之间必然夹带着lca的点,至于为什么可以画图理解一下,因为我们生成这个欧拉序时每次回溯就加一个点,而任意两点之间的搜索树一定是从lca开始往下搜,然后回溯再转而去搜另外一个点,所以lca就生成再两点的时间戳之间了。
于是我们维护完欧拉序后我们可以得到序列中深度最小的那个点必然是lca,这两点之间不会再夹带深度更小的点了,原因和上述蓝字一致。至此,整个问题从树上求LCA转为求区间序列深度最小的点,即RMQ问题,对于这个算法有个
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)预处理,
O
(
1
)
O(1)
O(1)查询的高效算法:ST表(基于动态规划和倍增思想)。这份博客有关于ST表求RMQ的讲解:戳这里。
不过我们的状态要重新设计一下,设
s
t
[
i
]
[
j
]
st[i][j]
st[i][j]表示起点为i,跳
2
j
2^j
2j步长的深度最小的点,这样设计的原因是我们要维护的是深度最小值,但要求的是最小值的那个点,不这样子还要多一次哈希,感觉没啥必要,直接找这个点,状态转移的时候哈希到深度就可以了。(其实就是少了n长度的空间)。其他的和st表的基本操作一致。
实现
#include <bits/stdc++.h>
using namespace std;
const int maxnn = (int)5e5+5;
const int maxnm = (int)1e6+5;
/**
* 利用欧拉序中两点的lca会包含在两个点的in之中的性质
* 查询lca(u, v)相当于查min(in[u], in[v])
* 区间最值查询可以用线段树 理想复杂度是log和倍增的复杂度一样 但常数大
* 这里用st表化为常数级区间最值查询
* 这是个在线算法
*/
int _to[maxnm], _next[maxnm], head[maxnn], cnt;
int Log[maxnm], Mi[21]; //注意这个log的大小,避免越界,例如洛谷越界有时候不会报re而是wa
int in[maxnn], seq[maxnm], deep[maxnn], id;
int st[maxnm][21]; //生成2*n-1的欧拉序起点为i,步长为2^j的序列深度最小的点
int n, m, s;
void edge_add(int u, int v) {
_to[cnt] = v;
_next[cnt] = head[u];
head[u] = cnt++;
}
void init() {
memset(head, -1, sizeof(head));
memset(deep, 0, sizeof(deep));
cnt = id = 0;
int x, y;
for (int i = 1; i < n; i++) {
scanf("%d%d", &x, &y);
edge_add(x, y);
edge_add(y, x);
}
}
void dfs(int cur, int parent) {
seq[++id] = cur;
in[cur] = id;
deep[cur]=deep[parent]+1;
for (int i = head[cur]; ~i; i=_next[i]) {
int v = _to[i];
if (v == parent) continue;
dfs(v, cur);
seq[++id] = cur;
}
}
void rmq_init() {
Log[0] = -1;
for (int i = 1; i <= id; i++)
Log[i] = Log[i>>1] + 1; //另外一种递推式:Log[i] = Log[i-1]+(1<<Log[i-1]==i) log(i)+1
Mi[0] = 1;
for (int i = 1; i <= 20; i++)
Mi[i] = Mi[i-1]<<1;
for (int i = 1; i <= id; i++)
st[i][0] = seq[i];
for (int j = 1; j <= Log[id]; j++) {
for (int i = 1; i+Mi[j] <= id+1; i++) {
//st[i][j] = min(st[i][j-1], st[i+Mi[j-1]][j-1])
if (deep[st[i][j-1]] <= deep[st[i+Mi[j-1]][j-1]]) {
st[i][j] = st[i][j-1];
} else {
st[i][j] = st[i+Mi[j-1]][j-1];
}
}
}
}
int query(int u, int v) {
if (u > v) swap(u, v);
int len = v - u + 1;
int k = Log[len];
int dc1 = st[u][k], dc2 = st[v-Mi[k]+1][k];
return deep[dc1] > deep[dc2] ? dc2 : dc1;
}
int main() {
scanf("%d%d%d", &n, &m, &s);
init();
dfs(s, 0);
rmq_init();
int x, y;
for (int i = 0; i < m; i++) {
scanf("%d%d", &x, &y);
printf("%d\n", query(in[x], in[y]));
}
return 0;
}