【题目链接】
【思路要点】
- 考虑计算 $2depth(x)+2depth(y)-2\big(depth(Lca(x,y))+depth'(Lca'(x,y))\big)$ 的最大值（即所求答案的 2 倍，最后输出时除以 2）。由 $dist(x,y)=depth(x)+depth(y)-2depth(Lca(x,y))$ ，它等于 $depth(x)+depth(y)+dist(x,y)-2depth'(Lca'(x,y))$ 。其中 $depth$、$dist$、$Lca$ 在第一棵树上计算，$depth'$、$Lca'$ 在第二棵树上计算。
- 枚举 $Lca'(x,y)$ ，我们需要维护一些集合，支持：
  (1) 建立一个只包含一个点的集合；
  (2) 查询在两个集合中各选一点时 $depth(x)+depth(y)+dist(x,y)$ 的最大值，并合并这两个集合。
- 先将第一棵树二叉化（为多叉节点引入虚点与零权边，点间距离不变），再构建边分结构，并为每一个点建立集合。注意到边分结构类似于线段树：它是一棵二叉树，深度为 $O(\log N)$ ，因此可以用线段树合并来实现操作 (2)。
- 时间复杂度 $O(N\log N)$ 。
【代码】
// Computes the maximum over vertex pairs (x, y), x = y allowed, of
//   depth(x) + depth(y) + dist(x, y) - 2 * depth'(Lca'(x, y)),
// where depth/dist come from the first tree read (T) and depth'/Lca' from the
// second tree read (T'), then prints half of that maximum.
//
// Method: binarise T and edge-divide it into a binary split hierarchy; give
// every real vertex a single-path "segment tree" along that hierarchy; then
// walk T' bottom-up and merge each child's tree into its parent's, scoring
// cross pairs during the merge (segment-tree merging).
//
// NOTE(review): rebuild/findroot/dfs/getans recurse once per tree level, so a
// path-shaped input with a few hundred thousand vertices needs a stack well
// above the common 8 MB default -- confirm the judge/runtime raises the limit.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <utility>
#include <vector>
using namespace std;

typedef long long ll;

// Vertex-indexed capacity; binarising T adds at most n - 1 virtual vertices.
const int MAXN = 8e5 + 5;
// Pool size for segment-tree nodes (one per level per real vertex, plus a leaf).
const int MAXP = 2e7 + 5;
const long long INF = 1e18;

template <typename T> void chkmax(T &x, T y) { x = max(x, y); }

// Reads an integer (a '-' seen before the digits flips the sign).
// Returns with whatever was parsed if EOF is hit instead of spinning forever.
template <typename T> void read(T &x) {
    x = 0;
    int f = 1;
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar())
        if (c == '-') f = -f;
    for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
    x *= f;
}

template <typename T> void write(T x) {
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

template <typename T> void writeln(T x) {
    write(x);
    puts("");
}

// Segment-tree node. At a node on level k, lMax is the best key among points
// that were on the rooty side of level k's split edge (their paths continue in
// lc), and rMax the best key among points on the rootx side (continue in rc).
// Index 0 is the empty tree.
struct Node {
    int lc, rc;
    ll lMax, rMax;
} t[MAXP];
int tsize;  // number of pool nodes handed out so far

ll ans;                          // best doubled value found so far (starts at 0)
ll depth[MAXN];                  // depth of each real vertex in T
int n, m;                        // number of real vertices / real + virtual
int root[MAXN];                  // segment-tree root per real vertex
int rootx, rooty;                // split edge chosen by the latest findroot
int curcol, col[MAXN];           // component colours used by the edge-divide
int subSize[MAXN];               // subtree sizes computed by findroot
                                 // (not `size`: that clashes with std::size)
int dir[MAXN];                   // per level: true if the point is on rooty's side
vector<ll> dist[MAXN];           // per real vertex: its key at each level
vector<pair<int, int>> a[MAXN];  // binarised T
vector<pair<int, int>> b[MAXN];  // T'
vector<pair<int, int>> c[MAXN];  // T as read

// Merges tree y into tree x node by node and returns the merged root. At each
// shared node it first scores pairs split across that level's edge,
// key(u) + key(v) + add, using x's best on one side and y's on the other.
int merge(int x, int y, ll add) {
    if (x == 0 || y == 0) return x + y;
    chkmax(ans, t[x].lMax + t[y].rMax + add);
    chkmax(ans, t[x].rMax + t[y].lMax + add);
    chkmax(t[x].lMax, t[y].lMax);
    chkmax(t[x].rMax, t[y].rMax);
    t[x].lc = merge(t[x].lc, t[y].lc, add);
    t[x].rc = merge(t[x].rc, t[y].rc, add);
    return x;
}

// DFS over T'; `len` is 2 * depth'(pos). Scores the pair x = y = pos, then
// merges each child's set into pos's set with offset -2 * depth'(pos), which
// scores every pair whose T'-LCA is pos.
void getans(int pos, int fa, ll len) {
    chkmax(ans, 2 * depth[pos] - len);
    for (const auto &e : b[pos]) {
        if (e.first == fa) continue;
        getans(e.first, pos, len + 2ll * e.second);
        root[pos] = merge(root[pos], root[e.first], -len);
    }
}

// Adds an undirected weighted edge to the binarised tree.
void addedge(int x, int y, int len) {
    a[x].emplace_back(y, len);
    a[y].emplace_back(x, len);
}

// Computes subtree sizes inside the component coloured `cur` (rooted at pos)
// and records in (rootx, rooty) the child/parent edge whose removal splits
// the `tot` vertices most evenly, i.e. minimises |2 * size - tot|.
void findroot(int pos, int fa, int cur, int tot) {
    subSize[pos] = 1;
    for (const auto &e : a[pos]) {
        if (e.first == fa || col[e.first] != cur) continue;
        findroot(e.first, pos, cur, tot);
        subSize[pos] += subSize[e.first];
    }
    if (abs(2 * subSize[pos] - tot) < abs(2 * subSize[rootx] - tot)) {
        rootx = pos;
        rooty = fa;
    }
}

// Recolours to `ncol` the part of component `cur` reachable from pos without
// passing through fa, and appends depth[v] + (distance walked + initial len)
// as the next-level key of every real vertex v visited.
void dfs(int pos, int fa, int cur, int ncol, ll len) {
    col[pos] = ncol;
    if (pos <= n) dist[pos].push_back(len + depth[pos]);
    for (const auto &e : a[pos]) {
        if (e.first != fa && col[e.first] == cur)
            dfs(e.first, pos, cur, ncol, len + e.second);
    }
}

// Creates real vertex pos's single-path tree: one node per level
// 0..levels-1, plus an empty leaf at `levels`. At level k the key
// dist[pos][k] goes to lMax and the path continues in lc when dir[k] is true,
// otherwise it goes to rMax / rc.
void modify(int &node, int pos, int now, int levels) {
    node = ++tsize;
    t[node].lMax = t[node].rMax = -INF;
    if (now == levels) return;
    if (dir[now]) {
        t[node].lMax = dist[pos][now];
        modify(t[node].lc, pos, now + 1, levels);
    } else {
        t[node].rMax = dist[pos][now];
        modify(t[node].rc, pos, now + 1, levels);
    }
}

// Edge-divides the component coloured `cur` (containing pos, `tot` vertices)
// at level `dep`. A single vertex ends the recursion and, if real, gets its
// path tree built from the dir[]/dist[] values recorded on the way down.
void buildst(int pos, int cur, int dep, int tot) {
    rootx = rooty = 0;
    findroot(pos, 0, cur, tot);
    if (subSize[pos] == 1) {
        if (pos <= n) modify(root[pos], pos, 0, dep);
        return;
    }
    int cutLen = 0;
    for (const auto &e : a[rootx])
        if (e.first == rooty) cutLen = e.second;
    // Keys on both sides measure distance to rooty (rootx's side includes
    // the cut edge), so key(u) + key(v) = depth(u) + depth(v) + dist(u, v).
    dfs(rootx, rooty, cur, ++curcol, cutLen);
    dfs(rooty, rootx, cur, ++curcol, 0);
    // The recursive calls overwrite rootx/rooty, so save this split first.
    const int lhs = rootx, rhs = rooty, lhsSize = subSize[rootx];
    dir[dep] = false, buildst(lhs, col[lhs], dep + 1, lhsSize);
    dir[dep] = true, buildst(rhs, col[rhs], dep + 1, tot - lhsSize);
}

// Copies T (adjacency c, rooted at pos) into `a`, binarising it. The first
// child hangs directly off pos, and each later child hangs off a fresh virtual
// vertex (id > n) chained from pos by zero-weight edges. Every vertex then has
// degree <= 3 and distances between real vertices are unchanged. Also fills
// depth[] along the way.
void rebuild(int pos, int fa) {
    int last = 0;
    for (const auto &e : c[pos]) {
        if (e.first == fa) continue;
        depth[e.first] = depth[pos] + e.second;
        rebuild(e.first, pos);
        if (last == 0) {
            addedge(pos, e.first, e.second);
            last = pos;
        } else {
            addedge(last, ++m, 0);
            addedge(m, e.first, e.second);
            last = m;
        }
    }
}

int main() {
    read(n), m = n;
    for (int i = 1; i <= n - 1; i++) {
        int x, y, z;
        read(x), read(y), read(z);
        c[x].emplace_back(y, z);
        c[y].emplace_back(x, z);
    }
    for (int i = 1; i <= n - 1; i++) {
        int x, y, z;
        read(x), read(y), read(z);
        b[x].emplace_back(y, z);
        b[y].emplace_back(x, z);
    }
    rebuild(1, 0);
    buildst(1, 0, 0, m);
    getans(1, 0, 0);
    writeln(ans / 2);
    return 0;
}