样例输入
4 5
1 2 30
2 3 50
2 4 60
2
3
4
2
1
样例输出
0
100
220
220
280
提示
题解
本题不断增删关键点,而且有一个显而易见的结论:不管从哪个点出发,每条路径都会走两遍。我们可以维护一棵动态虚树,记录虚树中的路径总长。具体实现需要开一个DFS序的set,增删点类比建虚树过程即可。
代码
#include<stdio.h>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<set>
using namespace std;
typedef long long ll;
const int maxn=1e5+5;
ll a,b,c,n,m,ans,cnt,Last[maxn];
ll vt,lca,id[maxn],dep[maxn],dfn[maxn],dis[maxn],fa[maxn][20];
set<ll> cache;
struct node
{
ll End,Next,Len;
}edge[2*maxn];
void save(ll x,ll y,ll z)
{
edge[++cnt].End=y,edge[cnt].Len=z;
edge[cnt].Next=Last[x],Last[x]=cnt;
}
void DFS(ll x)
{
dep[x]=dep[fa[x][0]]+1;
dfn[x]=++vt,id[vt]=x;
ll s=ceil(log2(dep[x]));
for(ll i=1;i<=s;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for(ll i=Last[x];i;i=edge[i].Next)
{
ll y=edge[i].End;
if(y==fa[x][0]) continue;
fa[y][0]=x,dis[y]=dis[x]+edge[i].Len;
DFS(y);
}
}
ll getlca(ll x,ll y)
{
if(dep[x]<dep[y]) swap(x,y);
ll k=dep[x]-dep[y],s=ceil(log2(n));
for(ll i=0;i<=s;i++)
if(k&(1<<i)) x=fa[x][i];
if(x==y) return x;
k=ceil(log2(dep[x]));
for(ll i=k;i>=0;i--)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
ll pre(ll x)
{
set<ll>::iterator y=cache.find(dfn[x]);
if(y==cache.begin()) return 0;
return id[*(--y)];
}
ll las(ll x)
{
set<ll>::iterator y=cache.find(dfn[x]);
y++;
if(y==cache.end()) return 0;
return id[*y];
}
void clr(ll x)
{
ll l=pre(x),r=las(x);
if(l) lca=getlca(x,l),ans-=dis[x]+dis[l]-2*dis[lca];
if(r) lca=getlca(x,r),ans-=dis[x]+dis[r]-2*dis[lca];
if(l&&r) lca=getlca(r,l),ans+=dis[r]+dis[l]-2*dis[lca];
cache.erase(dfn[x]);
}
void ins(ll x)
{
cache.insert(dfn[x]);
ll l=pre(x),r=las(x);
if(l) lca=getlca(x,l),ans+=dis[x]+dis[l]-2*dis[lca];
if(r) lca=getlca(x,r),ans+=dis[x]+dis[r]-2*dis[lca];
if(l&&r) lca=getlca(r,l),ans-=dis[r]+dis[l]-2*dis[lca];
}
int main()
{
scanf("%lld%lld",&n,&m);
for(ll i=1;i<n;i++)
{
scanf("%lld%lld%lld",&a,&b,&c);
save(a,b,c),save(b,a,c);
}
DFS(1);
while(m--)
{
scanf("%lld",&a);
if(!cache.size()) cache.insert(dfn[a]),puts("0");
else
{
if(cache.find(dfn[a])!=cache.end()) clr(a);
else ins(a);
lca=getlca(id[*cache.begin()],id[*(--cache.end())]);
printf("%lld\n",ans+dis[id[*cache.begin()]]+dis[id[*(--cache.end())]]-2*dis[lca]);
}
}
return 0;
}