题目描述
小B最近正在玩一个寻宝游戏,这个游戏的地图中有N个村庄和N-1条道路,并且任何两个村庄之间有且仅有一条路径可达。游戏开始时,玩家可以任意选择一个村庄,瞬间转移到这个村庄,然后可以任意在地图的道路上行走,若走到某个村庄中有宝物,则视为找到该村庄内的宝物,直到找到所有宝物并返回到最初转移到的村庄为止。
小B希望评测一下这个游戏的难度,因此他需要知道玩家找到所有宝物需要行走的最短路程。但是这个游戏中宝物经常变化,有时某个村庄中会突然出现宝物,有时某个村庄内的宝物会突然消失,因此小B需要不断地更新数据,但是小B太懒了,不愿意自己计算,因此他向你求助。为了简化问题,我们认为最开始时所有村庄内均没有宝物
输入输出格式
输入格式:
第一行,两个整数N、M,其中M为宝物的变动次数。接下来的N-1行,每行三个整数x、y、z,表示村庄x、y之间有一条长度为z的道路。接下来的M行,每行一个整数t,表示一个宝物变动的操作。若该操作前村庄t内没有宝物,则操作后村庄内有宝物;若该操作前村庄t内有宝物,则操作后村庄内没有宝物。
输出格式:
M行,每行一个整数,其中第i行的整数表示第i次操作之后玩家找到所有宝物需要行走的最短路程。若只有一个村庄内有宝物,或者所有村庄内都没有宝物,则输出0。
输入输出样例
输入样例#1:
4 5
1 2 30
2 3 50
2 4 60
2
3
4
2
1
输出样例#1:
0
100
220
220
280
说明
1<=N<=100000
1<=M<=100000
对于全部的数据,1<=z<=10^9
//set维护 DFS序 学号set真的好牛逼的~~
#include<iostream>
#include<cstring>
#include<cstdio>
#include<set>
using namespace std;
const int MAXN = 100010;
const int LogN = 25;
#define LL long long
int fa[MAXN][LogN],dep[MAXN],w[MAXN],vis[MAXN],visx,v[MAXN];
LL dis[MAXN],ans;
set<int> s;
set<int>:: iterator it;
struct Edge{ int to,next; LL w; }e[MAXN*2];
int head[MAXN],tot,n,m;
inline void Add_Edge(int u,int v,LL w){
e[++tot].to=v;e[tot].w=w;
e[tot].next=head[u];head[u]=tot;
}
inline int Get_LCA(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
for(int i=20;i>=0;i--)
if((dep[u]-dep[v])&(1<<i))
u=fa[u][i];
if(u==v) return v;
for(int i=20;i>=0;i--)
if(fa[u][i]!=fa[v][i]){
u=fa[u][i];v=fa[v][i];
}
return fa[u][0];
}
inline void DFS(int u,int father,int deepth){
dep[u]=deepth;fa[u][0]=father;vis[u]= ++visx;w[vis[u]]=u;
for(int i=head[u];i;i=e[i].next){
int v=e[i].to;
if(v==father) continue;
dis[v]=dis[u]+e[i].w;
DFS(v,u,deepth+1);
}
}
inline void DP(){
for(int j=1;j<=20;j++)
for(int i=1;i<=n;i++)
fa[i][j]=fa[fa[i][j-1]][j-1];
}
LL Distance(int x,int y){
return dis[x]+dis[y]-dis[Get_LCA(x,y)]*2;
}
int netation(int x){
it=s.find(vis[x]);
return ++it==s.end() ? 0 : w[*it];
}
int prepare(int x){
it=s.find(vis[x]);
return it==s.begin() ? 0 : w[*--it];
}
void Solve_1(int x){// Erase
int l=prepare(x),r=netation(x);
if(l) ans-=Distance(l,x);
if(r) ans-=Distance(x,r);
if(l&&r) ans+=Distance(l,r);
s.erase(vis[x]);
}
void Solve_2(int x){//Add
s.insert(vis[x]);
int l=prepare(x),r=netation(x);
if(l) ans+=Distance(l,x);
if(r) ans+=Distance(x,r);
if(l&&r) ans-=Distance(l,r);
}
int main(){
scanf("%d%d",&n,&m);
for(int x,y,i=1;i<=n-1;i++){
LL w;
scanf("%d%d%lld",&x,&y,&w);
Add_Edge(x,y,w);Add_Edge(y,x,w);
}
DFS(1,0,1);DP();ans=0;
for(int x,i=1;i<=m;i++){
scanf("%d",&x);
if(v[x]) Solve_1(x);
else Solve_2(x);
v[x]^=1;//取反操作
printf("%lld\n",s.size() ? ans+Distance(w[*s.begin()],w[*--s.end()]) : 0);
}
return 0;
}