题目链接
解题思路
神仙题。(真难写)
首先这组成的是一棵树。我们假定1为根。
然后,我们考虑使用dfs序来存下这颗树。这样有什么好处呢?我们把树上问题转换为了链上问题。(注意,这个dfs序要是个环)
我们动手画画图,我们就会发现,我们可以根据这个dfs序来计算每次加入、减少的贡献值。
使用平衡树来维护异象石序列,按照dfs序排序(注意首尾相连)
记当前节点为
p
p
p
那么我们会发现:每加入一个节点,它对于答案的贡献是
a
n
s
+
=
d
i
s
t
(
p
r
e
,
p
)
+
d
i
s
t
(
p
,
n
x
t
)
−
d
i
s
t
(
p
r
e
,
n
x
t
)
ans+=dist(pre,p)+dist(p,nxt)-dist(pre,nxt)
ans+=dist(pre,p)+dist(p,nxt)−dist(pre,nxt)
其中
d
i
s
t
dist
dist为树上距离,
p
r
e
pre
pre为前驱,
n
x
t
nxt
nxt为后继。
相应地,删除一个节点的贡献为
a
n
s
−
=
d
i
s
t
(
p
r
e
,
p
)
+
d
i
s
t
(
p
,
n
x
t
)
−
d
i
s
t
(
p
r
e
,
n
x
t
)
ans-=dist(pre,p)+dist(p,nxt)-dist(pre,nxt)
ans−=dist(pre,p)+dist(p,nxt)−dist(pre,nxt)
注意: 这样计算出的答案是题目所求的两倍,故询问?
要除以2。
其实这题不用手写平衡树(qwq),我们可以使用map啊(pascal选手:我究竟犯了什么错2333)
详细代码
#define USEFASTERREAD 1
#define rg register
#define inl inline
#define DEBUG printf("qwq\n")
#define DEBUGd(x) printf("var %s is %lld", #x, ll(x))
#define DEBUGf(x) printf("var %s is %llf", #x, double(x))
#define putln putchar('\n')
#define putsp putchar(' ')
#define Rep(a, s, t) for(rg int a = s; a <= t; a++)
#define Repdown(a, t, s) for(rg int a = t; a >= s; a--)
typedef long long ll;
typedef unsigned long long ull;
#include<cstdio>
#if USEFASTERREAD
char In[1 << 20], *ss = In, *tt = In;
#define getchar() (ss == tt && (tt = (ss = In) + fread(In, 1, 1 << 20, stdin), ss == tt) ? EOF : *ss++)
#endif
namespace IO {
inl void RS() {freopen("test.in", "r", stdin), freopen("test.out", "w", stdout);}
inl ll read() {
ll x = 0, f = 1; char ch = getchar();
for(; ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') f = -1;
for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + int(ch - '0');
return x * f;
}
inl char readc() {
char ch = getchar();
while(ch != '+' && ch != '-' && ch != '?') ch = getchar();
return ch;
}
inl void write(ll x) {
if(x < 0) {putchar('-'); x = -x;}
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
inl void writeln(ll x) {write(x), putln;}
inl void writesp(ll x) {write(x), putsp;}
}
using namespace IO;
template<typename T> inline T Max(const T& x, const T& y) {return y < x ? x : y;}
template<typename T> inline T Min(const T& x, const T& y) {return y < x ? y : x;}
template<typename T> inline void Swap(T& x, T& y) {T tmp = x; x = y; y = tmp;}
template<typename T> inline T Abs(const T& x) {return x < 0 ? -x : x;}
#include<set>
using std::set;
const int MAXN = 1e5 + 5;
int N, M;
struct Edge {
int v, nxt;
ll w;
}e[MAXN << 1];
int head[MAXN], cnt;
void addedge(int u, int v, ll w) {
e[++cnt].v = v;
e[cnt].w = w;
e[cnt].nxt = head[u];
head[u] = cnt;
}
int lg[MAXN];
int dep[MAXN];
int fa[MAXN][25];
ll dis[MAXN]; //to root node
void dfs(int u, int f, ll w) {
dep[u] = dep[f] + 1;
fa[u][0] = f;
dis[u] = dis[f] + w;
for(rg int i = 1; (1 << i) <= dep[u]; i++) fa[u][i] = fa[fa[u][i - 1]][i - 1];
for(rg int i = head[u]; i; i = e[i].nxt)
if(e[i].v != f)
dfs(e[i].v, u, e[i].w);
}
int lca(int x, int y) {
if(dep[x] < dep[y]) Swap(x, y);
while(dep[x] > dep[y])
x = fa[x][lg[dep[x] - dep[y]]];
if(x == y) return x;
for(rg int i = lg[dep[x]]; i >= 0; i--)
if(fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
ll dist(int x, int y) {
int l = lca(x, y);
return dis[x] + dis[y] - dis[l] * 2;
}
int pth[MAXN];
int dfn[MAXN], tim;
void dfs(int u, int f) {
dfn[u] = ++tim;
pth[tim] = u;
for(rg int i = head[u]; i; i = e[i].nxt)
if(e[i].v != f) dfs(e[i].v, u);
}
set<int> mp;
typedef set<int>::iterator iter;
ll ans;
int main() {
//RS();
N = read();
for(rg int i = 1; i < N; i++) {
int x = read(), y = read(), z = read();
addedge(x, y, z);
addedge(y, x, z);
}
lg[0] = -1;
for(rg int i = 1; i < MAXN; i++) lg[i] = lg[i >> 1] + 1;
dfs(1, 0, 0);
dfs(1, 0);
M = read();
for(rg int i = 1; i <= M; i++) {
char opt = readc();
if(opt == '+') {
int x = read();
mp.insert(dfn[x]);
if(mp.empty()) continue;
iter p = mp.find(dfn[x]);
iter pre = p;
if(pre == mp.begin()) pre = mp.end(), pre--;
else pre--;
iter nxt = p;
nxt++; if(nxt == mp.end()) nxt = mp.begin();
ll ans1 = dist(pth[*pre], pth[*nxt]);
ll ans2 = dist(pth[*pre], pth[*p]);
ll ans3 = dist(pth[*p], pth[*nxt]);
ans += -ans1 + ans2 + ans3;
} else if(opt == '-') {
int x = read();
iter p = mp.find(dfn[x]);
iter pre = p;
if(pre == mp.begin()) pre = mp.end(), pre--;
else pre--;
iter nxt = p;
nxt++; if(nxt == mp.end()) nxt = mp.begin();
ll ans1 = dist(pth[*pre], pth[*nxt]);
ll ans2 = dist(pth[*pre], pth[*p]);
ll ans3 = dist(pth[*p], pth[*nxt]);
ans += ans1 - ans2 - ans3;
mp.erase(p);
} else {
writeln(ans / 2);
}
}
return 0;
}