考虑对所有的关键点构建一个虚树，那么最后结果就是虚树中所有边权之和的 $2$ 倍。但是此题对每次询问的关键点数量之和是没有限制的（每次询问都可能有 $O(n)$ 个关键点），因此在最坏情况下这样做的复杂度会达到 $O(nm\log n)$，无法接受。
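继续分析，可以发现如果将每次的关键点按 dfs 序排序，那么从第 $1$ 个关键点走到第 $2$ 个关键点，再走到第 $3$ 个关键点……一直走到第 $k$ 个关键点，最后再回到第 $1$ 个关键点，所走的总路程恰好就是答案（即虚树边权和的 $2$ 倍）。因此可以用一棵以 dfs 序为键的平衡树（下面的代码使用 Splay）维护关键点集合：插入或删除一个关键点时，只需找到这个点在关键点集合中（按循环意义）的前驱 $p$ 和后继 $s$，就可以计算对答案的影响——插入点 $x$ 时答案增加 $\operatorname{dist}(p,x)+\operatorname{dist}(x,s)-\operatorname{dist}(p,s)$，删除时减去同样的量。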
代码:
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
// Reads the next decimal integer (optionally preceded by '-') from `in`,
// skipping every character before it that is neither a digit nor '-'.
// The character that terminates the number is consumed.
// Returns 0 if the stream ends before a digit or '-' is found. The original
// version looped forever in that case, because it never tested for EOF and it
// stored getchar()'s result in a char.
// `in` defaults to stdin, so existing calls `read()` are unchanged.
inline int read(std::FILE* in = stdin) {
  int c = std::getc(in);  // int, not char, so EOF stays distinguishable
  while (c != EOF && c != '-' && (c < '0' || c > '9')) c = std::getc(in);
  if (c == EOF) return 0;
  bool neg = false;
  int res = 0;
  if (c == '-') neg = true;
  else res = c - '0';
  while ((c = std::getc(in)) >= '0' && c <= '9') res = res * 10 + (c - '0');
  return neg ? -res : res;
}
typedef long long ll;

// Capacities: N vertices, M directed arc slots (add_edge stores two per
// undirected edge), L splay nodes, LogN binary-lifting levels.
const int N = 100005;
const int M = 200005;
const int L = 700005;
const int LogN = 20;

// Tree as a weighted adjacency list, plus DFS results:
// dfn = preorder index, dep = depth, F[v][k] = 2^k-th ancestor (0 above root).
int n, ecnt;
int nxt[M], adj[N], go[M], val[M];
int dfn[N], F[N][LogN], times, dep[N];

// Splay tree over the key vertices keyed by dfn: fa/lc/rc links,
// id[node] = vertex stored in the node, RT = root, QAQ = nodes allocated.
int m;
int fa[L], lc[L], rc[L], id[L];
int RT, QAQ;

// dis[v] = weighted distance from the DFS root; ans = current answer;
// in[v] = whether v is currently a key vertex.
ll dis[N], ans;
bool in[N];
// Adds the undirected weighted edge {u, v} as two directed arcs u->v and v->u,
// each prepended to its source vertex's adjacency list.
void add_edge(int u, int v, int w) {
  const int ends[2][2] = {{u, v}, {v, u}};
  for (int k = 0; k < 2; ++k) {
    const int from = ends[k][0];
    ++ecnt;
    go[ecnt] = ends[k][1];
    val[ecnt] = w;
    nxt[ecnt] = adj[from];
    adj[from] = ecnt;
  }
}
void dfs(int u, int fu) {
dfn[u] = ++times; int i; dep[u] = dep[F[u][0] = fu] + 1;
for (i = 0; i <= 17; i++) F[u][i + 1] = F[F[u][i]][i];
for (int e = adj[u], v; e; e = nxt[e])
if ((v = go[e]) != fu) dis[v] = dis[u] + val[e], dfs(v, u);
}
// Weighted tree distance between u and v:
// dis[u] + dis[v] - 2 * dis[lca(u, v)], with the LCA found by binary lifting.
ll dist(int u, int v) {
  const int a = u, b = v;
  if (dep[u] < dep[v]) std::swap(u, v);
  // Lift the deeper endpoint to the depth of the shallower one.
  for (int k = 18; k >= 0; --k)
    if (dep[F[u][k]] >= dep[v]) u = F[u][k];
  int lca = u;
  if (u != v) {
    // Lift both just below their lowest common ancestor.
    for (int k = 18; k >= 0; --k) {
      if (F[u][k] != F[v][k]) {
        u = F[u][k];
        v = F[v][k];
      }
    }
    lca = F[u][0];
  }
  return dis[a] + dis[b] - 2 * dis[lca];
}
// 1 if x is the right child of its parent, 0 otherwise.
int which(int x) {
  const int p = fa[x];
  return rc[p] == x ? 1 : 0;
}
// Single rotation lifting x above its parent y. x takes y's place under the
// grandparent z, and x's inner child b is handed over to y.
void rotate(int x) {
  const int y = fa[x];
  const int z = fa[y];
  const bool isLeft = (lc[y] == x);
  const int b = isLeft ? rc[x] : lc[x];
  if (z) {
    if (lc[z] == y) lc[z] = x;
    else rc[z] = x;
  }
  fa[x] = z;
  fa[y] = x;
  if (b) fa[b] = y;
  if (isLeft) {
    rc[x] = y;
    lc[y] = b;
  } else {
    lc[x] = y;
    rc[y] = b;
  }
}
// Splays x upward until its parent is `tar`. With tar == 0, x becomes the root
// and RT is updated.
// Each step is a standard zig, zig-zig or zig-zag: when x and its parent are
// children on the same side, the parent is rotated first; x is then always
// rotated. The previous version did only one rotation per inner iteration, so a
// zig-zig never rotated x after its parent. That left long paths undiminished
// and lost the splay tree's amortized O(log n) bound.
void splay(int x, int tar) {
  while (fa[x] != tar) {
    const int y = fa[x];
    if (fa[y] != tar) {
      if (which(x) == which(y)) rotate(y);  // zig-zig
      else rotate(x);                       // zig-zag
    }
    rotate(x);
  }
  if (!tar) RT = x;
}
// Inserts vertex v as a fresh splay node keyed by dfn[v] and splays it to the root.
// Keys equal to an existing node's key descend to the right.
void ins(int v) {
  int parent = 0, goRight = 0;
  for (int cur = RT; cur != 0; cur = goRight ? rc[cur] : lc[cur]) {
    parent = cur;
    goRight = dfn[v] < dfn[id[cur]] ? 0 : 1;
  }
  const int node = ++QAQ;
  id[node] = v;
  fa[node] = parent;
  if (parent != 0) {
    if (goRight) rc[parent] = node;
    else lc[parent] = node;
  }
  splay(node, 0);
}
// Returns the splay node whose stored vertex has the same dfn as v, or -1 if
// no such node exists. Does not splay.
int fin(int v) {
  const int key = dfn[v];
  int node = RT;
  while (node && dfn[id[node]] != key)
    node = key < dfn[id[node]] ? lc[node] : rc[node];
  return node ? node : -1;
}
// Merges the left subtree x and right subtree y of a root being removed.
// x is hung as the left child of y's minimum node, y becomes the root, and
// that minimum node is then splayed to the root.
void join(int x, int y) {
  // Detach both subtrees from their common parent (the node being removed).
  rc[fa[y]] = 0;
  lc[fa[x]] = 0;
  int succ = y;
  while (lc[succ]) succ = lc[succ];
  lc[succ] = x;
  fa[x] = succ;
  RT = y;
  fa[y] = 0;
  splay(succ, 0);
}
void del(int x) {
splay(x, 0); if (!lc[x] || !rc[x])
fa[RT = lc[x] + rc[x]] = 0;
else join(lc[x], rc[x]);
}
// Adds vertex x to the key set and updates ans, the length of the closed walk
// visiting the key vertices in dfn order.
// With `before`/`after` being x's cyclic predecessor/successor, the edge
// before->after is replaced by before->x->after.
void new_node(int x) {
  ins(x);
  const int root = RT;
  if (lc[root] || rc[root]) {
    // Cyclic predecessor: max of the left subtree, or the overall max if none.
    int before = lc[root] ? lc[root] : root;
    while (rc[before]) before = rc[before];
    // Cyclic successor: min of the right subtree, or the overall min if none.
    int after = rc[root] ? rc[root] : root;
    while (lc[after]) after = lc[after];
    ans += dist(id[before], id[root]) + dist(id[root], id[after]) -
           dist(id[before], id[after]);
  }
  in[x] = 1;
}
// Removes vertex x from the key set (it must currently be present) and updates
// ans. The detour before->x->after in the dfn-ordered closed walk is replaced
// by the direct edge before->after.
void del_node(int x) {
  splay(fin(x), 0);
  const int root = RT;
  if (lc[root] || rc[root]) {
    // Cyclic predecessor: max of the left subtree, or the overall max if none.
    int before = lc[root] ? lc[root] : root;
    while (rc[before]) before = rc[before];
    // Cyclic successor: min of the right subtree, or the overall min if none.
    int after = rc[root] ? rc[root] : root;
    while (lc[after]) after = lc[after];
    ans -= dist(id[before], id[root]) + dist(id[root], id[after]) -
           dist(id[before], id[after]);
  }
  del(root);
  in[x] = 0;
}
int main() {
int i, x, y, z; n = read(); m = read();
for (i = 1; i < n; i++) x = read(), y = read(), z = read(),
add_edge(x, y, z);
dfs(1, 0); while (m--) {
x = read(); if (in[x]) del_node(x); else new_node(x);
printf("%lld\n", ans);
}
return 0;
}