题意
给定两棵树\(T, T'\),求
\[ \max_{(x, y)} dep_x + dep_y - (dep_{lca(x, y)} + dep'_{lca'(x, y)}) \]
题解
经过一些化简后就变成了求
\[ \max_{(x, y)} \frac{1}{2} dep_x + \frac{1}{2} dep_y + \frac{1}{2} dis(x, y) - dep'_{lca'(x, y)} \]
有两种主流做法。
一种是边分+虚树+dp。
这种还没实现过。
另一种是边分+点分树合并。
考虑对第一棵树进行边分,设\(d_x\)为\(x\)到当前边分中心的距离,则求
\[ \max_{(x, y)} \frac{1}{2} (dep_x + d_x) + \frac{1}{2} (dep_y + d_x) - dep'_{lca'(x, y)} \]
在一次边分中,前半部分是容易求得的。
考虑固定后面部分,即枚举\(x, y\)在第二棵树上的lca。
枚举了lca,要保证\(x, y\)在lca的子树中,怎么办呢?
考虑用对于第二棵树的每一个点建一类棵动态开点线段树,每棵这种树维护在第一棵树中经过该点在不同的点分子树中的贡献(要记录在不同的点分中心的左右两侧的产生的贡献,贡献的形式就是\(dep_x + d_x\)),然后在第二棵树上从下往上合并即可,合并的时候对于两个点在同一棵点分子树的同一个部分(左/右)的贡献取max。并且左边和左边合并,右边和右边合并,左边只能和右边产生贡献,右边只能和左边产生贡献,就和线段树合并类似了。至于复杂度,大概就是根据边分树的优越结构的性质(完全二叉树),就很优秀就是了。
大概是\(\mathcal O(n \log n)\)的。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 750005;
struct tr {
int tot, lnk[N], nxt[N << 1], son[N << 1], w[N << 1];
tr () {
tot = 1;
memset(lnk, 0, sizeof lnk);
memset(nxt, 0, sizeof nxt);
}
void add (int x, int y, int z) {
nxt[++tot] = lnk[x], lnk[x] = tot, son[tot] = y, w[tot] = z;
}
void adds (int x, int y, int z) {
add(x, y, z), add(y, x, z);
}
} T0, T1, T2;
int n, m, s, e, mn, tot; ll ans;
int sz[N], *pos[N], rt[N], las[N], lc[N * 20], rc[N * 20];
ll dep[N], mx[N * 20][2];
bool ban[N << 1];
void append (int x, int y, int z) {
T1.adds(las[x], ++m, 0), las[x] = m;
T1.adds(m, y, z);
}
void build (int x, int p, ll d) {
dep[x] = d;
for (int j = T0.lnk[x]; j; j = T0.nxt[j]) {
if (T0.son[j] != p) {
append(x, T0.son[j], T0.w[j]);
build(T0.son[j], x, d + T0.w[j]);
}
}
}
void gete (int x, int p) {
sz[x] = 1;
for (int j = T1.lnk[x], v; j; j = T1.nxt[j]) {
if (!ban[j >> 1] && T1.son[j] != p) {
gete(T1.son[j], x);
sz[x] += sz[T1.son[j]];
v = max(sz[T1.son[j]], s - sz[T1.son[j]]);
if (mn > v) {
e = j, mn = v;
}
}
}
}
void dfs (int x, int p, ll d, bool f) {
if (x <= n) {
*pos[x] = ++tot;
mx[tot][f] = d + dep[x];
mx[tot][f ^ 1] = -1e18;
pos[x] = (f ? &rc[tot] : &lc[tot]);
}
for (int j = T1.lnk[x]; j; j = T1.nxt[j]) {
if (!ban[j >> 1] && T1.son[j] != p) {
dfs(T1.son[j], x, d + T1.w[j], f);
}
}
}
void dac (int x, int curs) {
if (curs > 1) {
s = curs, e = -1, mn = 1e9;
gete(x, 0);
ban[e >> 1] = 1;
int u = T1.son[e ^ 1], v = T1.son[e], su = s - sz[v], sv = sz[v];
dfs(u, 0, 0, 0), dfs(v, 0, T1.w[e], 1);
dac(u, su), dac(v, sv);
}
}
int merge (int x, int y, ll d) {
if (!x || !y) {
return x | y;
}
ans = max(ans, (mx[x][0] + mx[y][1]) / 2 - d);
ans = max(ans, (mx[y][0] + mx[x][1]) / 2 - d);
mx[x][0] = max(mx[x][0], mx[y][0]);
mx[x][1] = max(mx[x][1], mx[y][1]);
lc[x] = merge(lc[x], lc[y], d);
rc[x] = merge(rc[x], rc[y], d);
return x;
}
void solve (int x, int p, ll d) {
ans = max(ans, dep[x] - d);
for (int j = T2.lnk[x]; j; j = T2.nxt[j]) {
if (T2.son[j] != p) {
solve(T2.son[j], x, d + T2.w[j]);
rt[x] = merge(rt[x], rt[T2.son[j]], d);
}
}
}
int main () {
scanf("%d", &n);
for (int i = 1, x, y, z; i < n; ++i) {
scanf("%d%d%d", &x, &y, &z);
T0.adds(x, y, z);
}
for (int i = 1; i <= n; ++i) {
las[i] = i, pos[i] = &rt[i];
}
m = n;
build(1, 0, 0);
dac(1, m);
for (int i = 1, x, y, z; i < n; ++i) {
scanf("%d%d%d", &x, &y, &z);
T2.adds(x, y, z);
}
solve(1, 0, 0);
printf("%lld\n", ans);
return 0;
}