算法原理:最近公共祖先 - OI Wiki
例题1: Problem - 2586
下面给出两个示例代码,一个是使用cost[a][i],表示a节点到其第2的i次方个父亲的距离.另外一个是使用dist[a]数组,存储节点a到根节点的距离,两种做法原理相同.
代码1:cost[a][i]做法
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 4e4 + 10;
vector<int> v[N], e[N]; //记录边的下一节点 和 边的权重
int fa[N][31], cost[N][31], dep[N];
// cost[N][i]记录的是 N到其第2的i次方个父节点的距离 不是根节点
int t;
int n, m;
// dfs里面主要是要完成 fa cost dep三个数组的初始化
void dfs(int u, int f) {
//首先处理0位
fa[u][0] = f;
dep[u] = dep[f] + 1;
//然后用爸爸的爸爸是爷爷 处理fa 和 cost数组
for (int i = 1; i < 31; i++) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
cost[u][i] = cost[u][i - 1] + cost[fa[u][i - 1]][i - 1];
}
//然后遍历所有子节点 往下dfs
for (int i = 0; i < v[u].size(); i++) {
int t = v[u][i];
if (t == fa[u][0]) continue;
cost[t][0] = e[u][i]; //记得需要初始处理 t到其父亲的距离
dfs(t, u);
}
}
int lca(int a, int b) {
//让a深度更大
if (dep[a] < dep[b]) swap(a, b);
int ans = 0;
int dif = dep[a] - dep[b];
//然后两个点,根据差值 跳到同一深度
//顺序不重要 反正一定要跳完的
// i需要从0开始 因为0代表的是跳1步
for (int i = 0; dif; i++, dif >>= 1) {
if (dif & 1) {
ans += cost[a][i];
a = fa[a][i];
}
}
if (a == b) return ans;
for (int i = 30; i >= 0; i--) {
//如果这步是可以跳的话,那就跳
//注意不能跳过lca
if (fa[a][i] != fa[b][i]) {
ans += cost[a][i] + cost[b][i];
a = fa[a][i];
b = fa[b][i];
}
}
return ans + cost[a][0] + cost[b][0];
}
void solve() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; i++) {
v[i].clear();
e[i].clear();
}
for (int i = 0; i < n - 1; i++) {
int a, b, c;
scanf("%lld%lld%lld", &a, &b, &c);
v[a].push_back(b);
v[b].push_back(a);
e[a].push_back(c);
e[b].push_back(c);
}
dfs(1, 0);
while (m--) {
int a, b;
scanf("%lld%lld", &a, &b);
cout << lca(a, b) << endl;
}
}
signed main() {
scanf("%lld", &t);
while (t--) solve();
return 0;
}
代码2:dist[a]做法
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 4e4 + 10;
vector<int> v[N], e[N]; //记录边的下一节点 和 边的权重
int fa[N][31], cost[N][31], dep[N];
int dist[N];
// cost[N][i]记录的是 N到其第2的i次方个父节点的距离 不是根节点
int t;
int n, m;
// dfs里面主要是要完成 fa cost dep三个数组的初始化
void dfs(int u, int f, int d) {
//首先处理0位
fa[u][0] = f;
dep[u] = dep[f] + 1;
dist[u] = d;
//然后用爸爸的爸爸是爷爷 处理fa 和 cost数组
for (int i = 1; i < 31; i++) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
// cost[u][i] = cost[u][i - 1] + cost[fa[u][i - 1]][i - 1];
}
//然后遍历所有子节点 往下dfs
for (int i = 0; i < v[u].size(); i++) {
int t = v[u][i];
if (t == fa[u][0]) continue;
// cost[t][0] = e[u][i]; //记得需要初始处理 t到其父亲的距离
dfs(t, u, d + e[u][i]);
}
}
int lca(int a, int b) {
int prea = a, preb = b;
//让a深度更大
if (dep[a] < dep[b]) swap(a, b);
int ans = 0;
int dif = dep[a] - dep[b];
//然后两个点,根据差值 跳到同一深度
//顺序不重要 反正一定要跳完的
// i需要从0开始 因为0代表的是跳1步
for (int i = 0; dif; i++, dif >>= 1) {
if (dif & 1) {
a = fa[a][i];
}
}
if (a == b) return dist[prea] + dist[preb] - 2 * dist[a];
for (int i = 30; i >= 0; i--) {
//如果这步是可以跳的话,那就跳
//注意不能跳过lca
if (fa[a][i] != fa[b][i]) {
// ans += cost[a][i] + cost[b][i];
a = fa[a][i];
b = fa[b][i];
}
}
// return ans + cost[a][0] + cost[b][0];
return dist[prea] + dist[preb] - 2 * dist[fa[a][0]];
}
void solve() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; i++) {
v[i].clear();
e[i].clear();
}
for (int i = 0; i < n - 1; i++) {
int a, b, c;
scanf("%lld%lld%lld", &a, &b, &c);
v[a].push_back(b);
v[b].push_back(a);
e[a].push_back(c);
e[b].push_back(c);
}
dfs(1, 0, 0);
while (m--) {
int a, b;
scanf("%lld%lld", &a, &b);
cout << lca(a, b) << endl;
}
}
signed main() {
scanf("%lld", &t);
while (t--) solve();
return 0;
}
例题二:用户登录
思路:先把整次的总时长求出来,如果删的是第一个点或最后一个点,那就只要删一条边.否则就需要删掉k-1,k与k,k+1这两条边,并且把k-1和k+1连接起来
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e5 + 10;
int n, k;
vector<int> v[N], e[N];
int dep[N], dist[N], a[N], fa[N][31];
void dfs(int u, int f, int d) {
fa[u][0] = f;
dist[u] = d;
dep[u] = dep[f] + 1; //记住深度一定要记录
//然后就是找fa数组
for (int i = 1; i < 31; i++) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
}
for (int i = 0; i < v[u].size(); i++) {
int t = v[u][i];
if (t == fa[u][0]) continue;
dfs(t, u, d + e[u][i]);
}
}
int lca(int a, int b) {
int ans = dist[a] + dist[b];
if (dep[a] < dep[b]) swap(a, b);
//求差
int dif = dep[a] - dep[b];
for (int j = 0; dif; j++, dif >>= 1) {
if (dif & 1) a = fa[a][j];
}
if (a == b) return ans - 2 * dist[a];
for (int i = 30; i >= 0; i--) {
if (fa[a][i] != fa[b][i]) {
a = fa[a][i];
b = fa[b][i];
}
}
return ans - 2 * dist[fa[a][0]];
}
signed main() {
scanf("%lld%lld", &n, &k);
for (int i = 0; i < n - 1; i++) {
int a, b, t;
scanf("%lld%lld%lld", &a, &b, &t);
v[a].push_back(b);
v[b].push_back(a);
e[a].push_back(t);
e[b].push_back(t);
}
dfs(1, 0, 0);
int ans = 0;
for (int i = 0; i < k; i++) {
scanf("%d", &a[i]);
if (i) ans += lca(a[i - 1], a[i]);
}
for (int i = 0; i < k; i++) {
int t;
if (i == 0)
t = ans - lca(a[0], a[1]);
else if (i == k - 1)
t = ans - lca(a[k - 2], a[k - 1]);
else
t = ans - lca(a[i - 1], a[i]) - lca(a[i], a[i + 1]) +
lca(a[i - 1], a[i + 1]);
printf("%lld ", t);
}
}
例题三:用户登录
初步思路:预先求好 每一个(a,b)组合的最小公共祖先,我们只要判断,每一个ab组的最小公共祖先 是不是 要删的这条边的祖宗即可.
时间复杂度O(m*logn*n) 由于每次需要找一下祖宗,所以最坏的情况下,可能会超时
PS(验证了一下,想假了实际上这题是树上差分)
正确思路:如果删去边(x,y),a,b会变得不连通,那么这条边一定在a->LCA(a,b)或b->LCA(a,b)上面,我们将这个区间上的标记+1,最后查看那条边的标记为m,其则满足条件
注:需要注意的是,我们区间修改的时候,需要做树上差分,最后再dfssum一遍,让差分数组变成原数组
树上差分教程:前缀和 & 差分 - OI Wiki
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1e5 + 10;
int n, k;
struct node {
int u, v;
} temp[N];
vector<int> v[N];
int dep[N], a[N], b[N], d[N];
int fa[N][31];
void dfs(int u, int f) {
fa[u][0] = f;
dep[u] = dep[f] + 1; //记住深度一定要记录
//然后就是找fa数组
for (int i = 1; i < 31; i++) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
}
for (int i = 0; i < v[u].size(); i++) {
int t = v[u][i];
if (t == fa[u][0]) continue;
dfs(t, u);
}
}
int lca(int a, int b) {
if (dep[a] < dep[b]) swap(a, b);
//求差
int dif = dep[a] - dep[b];
for (int j = 0; dif; j++, dif >>= 1) {
if (dif & 1) a = fa[a][j];
}
if (a == b) return a;
for (int i = 30; i >= 0; i--) {
if (fa[a][i] != fa[b][i]) {
a = fa[a][i];
b = fa[b][i];
}
}
return fa[a][0];
}
void dfsSum(int u, int f) {
for (int x : v[u]) {
if (x == f) continue;
dfsSum(x, u);
d[u] += d[x];
}
}
signed main() {
scanf("%lld%lld", &n, &k);
for (auto i = 0; i < n - 1; i++) {
int a, b;
scanf("%lld%lld", &a, &b);
v[a].push_back(b);
v[b].push_back(a);
temp[i].u = a;
temp[i].v = b;
}
dfs(1, 0);
for (int i = 0; i < k; i++) {
scanf("%lld%lld", &a[i], &b[i]);
int t = lca(a[i], b[i]);
d[a[i]] += 1;
d[b[i]] += 1;
d[t] -= 2;
}
dfsSum(1, 0);
for (int i = n - 2; i >= 0; i--) {
int t = dep[temp[i].u] > dep[temp[i].v] ? temp[i].u : temp[i].v;
if (d[t] == k) {
printf("%lld\n", i + 1);
return 0;
}
}
printf("-1\n");
}