#Description
小B最近正在玩一个寻宝游戏,这个游戏的地图中有N个村庄和N-1条道路,并且任何两个村庄之间有且仅有一条路径可达。游戏开始时,玩家可以任意选择一个村庄,瞬间转移到这个村庄,然后可以任意在地图的道路上行走,若走到某个村庄中有宝物,则视为找到该村庄内的宝物,直到找到所有宝物并返回到最初转移到的村庄为止。小B希望评测一下这个游戏的难度,因此他需要知道玩家找到所有宝物需要行走的最短路程。但是这个游戏中宝物经常变化,有时某个村庄中会突然出现宝物,有时某个村庄内的宝物会突然消失,因此小B需要不断地更新数据,但是小B太懒了,不愿意自己计算,因此他向你求助。为了简化问题,我们认为最开始时所有村庄内均没有宝物
1<=N<=100000
1<=M<=100000
对于全部的数据,1<=z<=10^9
#Solution
今天打noip模拟T3和这个做法好像是类似的,然后我发现这道题之前写过但是没有A,现在终于补上了
答案显然为选中的dfs序相邻点的距离和(绕,考虑用平衡树(set)维护这个东西,我们只需要在插入和删除的时候把贡献改一下即可
#Code
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <set>
#define rep(i,st,ed) for (int i=st;i<=ed;++i)
#define drp(i,st,ed) for (int i=st;i>=ed;--i)
typedef long long LL;
const int N=200005;
const int E=200005;
struct edge {int y,w,next;} e[E];
struct data {
int fi,se;
friend bool operator <(data a,data b) {
return (a.fi==b.fi)?(a.se<b.se):(a.fi<b.fi);
}
};
std:: set <data> set;
std:: set <data>:: iterator pos;
LL dis[N];
int dfn[N],dep[N],fa[N][21];
int ls[N],edCnt;
bool vis[N];
int read() {
int x=0,v=1; char ch=getchar();
for (;ch<'0'||ch>'9';v=(ch=='-')?(-1):(v),ch=getchar());
for (;ch<='9'&&ch>='0';x=x*10+ch-'0',ch=getchar());
return x*v;
}
void addEdge(int x,int y,int w) {
e[++edCnt]=(edge) {y,w,ls[x]}; ls[x]=edCnt;
e[++edCnt]=(edge) {x,w,ls[y]}; ls[y]=edCnt;
}
void dfs(int now) {
dfn[now]=++dfn[0];
rep(i,1,20) fa[now][i]=fa[fa[now][i-1]][i-1];
for (int i=ls[now];i;i=e[i].next) {
if (e[i].y==fa[now][0]) continue;
dep[e[i].y]=dep[now]+1;
dis[e[i].y]=dis[now]+e[i].w;
fa[e[i].y][0]=now;
dfs(e[i].y);
}
}
int get_lca(int x,int y) {
if (dep[x]<dep[y]) std:: swap(x,y);
drp(i,20,0) if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
if (x==y) return x;
drp(i,20,0) if (fa[x][i]!=fa[y][i]) {
x=fa[x][i];
y=fa[y][i];
}
return fa[x][0];
}
LL get_dis(int x,int y) {
int lca=get_lca(x,y);
return dis[x]+dis[y]-2*dis[lca];
}
int main(void) {
int n,m; scanf("%d%d",&n,&m);
rep(i,2,n) {
int x=read(),y=read(),w=read();
addEdge(x,y,w);
}
dfs(1); LL ans=0;
while (m--) {
int x=read(),a=0,b=0,st=1,ed=1; data now={dfn[x],x};
vis[x]^=1;
if (vis[x]) {
set.insert(now); pos=set.find(now);
if (set.size()==1) {
puts("0");
continue;
}
if (pos!=set.begin()) {
pos--; a=(*pos).se;
ans+=get_dis(x,a);
pos++;
}
pos++;
if (pos!=set.end()) {
b=(*pos).se;
ans+=get_dis(x,b);
}
pos--;
if (a&&b) ans-=get_dis(a,b);
} else {
pos=set.find(now);
if (set.size()==1) {
set.erase(now);
puts("0");
continue;
}
if (pos!=set.begin()) {
pos--; a=(*pos).se;
ans-=get_dis(x,a);
pos++;
}
pos++;
if (pos!=set.end()) {
b=(*pos).se;
ans-=get_dis(x,b);
}
pos--;
if (a&&b) ans+=get_dis(a,b);
set.erase(now);
}
if (set.size()) {
pos=set.begin(); st=(*pos).se;
pos=set.end(); pos--; ed=(*pos).se;
}
printf("%lld\n", ans+get_dis(st,ed));
}
return 0;
}