Description
一棵树,求经过所有黑点的最短回路。\(n\leqslant 10^5\)
Solution
set DFS序。
一个回路,这个回路可以是所有黑点之间路径上的点作为起点,然后按照DFS序遍历所有点得到。
然后用set维护黑点按DFS序排序后相邻两点之间的距离和即可,注意最后一个点也要和第一个点计算。
Code
/**************************************************************
Problem: 3991
User: BeiYu
Language: C++
Result: Accepted
Time:10800 ms
Memory:92712 kb
****************************************************************/
#include <bits/stdc++.h>
using namespace std;
#define mpr make_pair
#define x first
#define y second
typedef long long LL;
typedef pair<int,LL> pr;
const int N = 200050;
const int M = 22;
inline int in(int x=0,char ch=getchar()) { while(ch>'9' || ch<'0') ch=getchar();
while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();return x; }
int n,m,cnt,q;
vector<pr> h[N];
int mk[N];
LL pw2[N],lg2[N];
LL dfn[N],d[N],dis[N],pos[N];
LL f[N][M],g[N][M];
LL ans;
set<pr> S;
void AddEdge(int u,int v,LL w) {
h[u].push_back(mpr(v,w)),h[v].push_back(mpr(u,w));
}
void DFS(int u,int fa,LL dp) {
dfn[++m]=u,pos[u]=m,f[m][0]=u,d[u]=d[fa]+1,dis[u]=dp;
for(int i=0,v;i<(int)h[u].size();i++) if((v=h[u][i].x)!=fa) {
DFS(v,u,dp+h[u][i].y),dfn[++m]=u,f[m][0]=u;
}
}
void init() {
memset(d,0x3f,sizeof(d));
DFS(1,1,0);
pw2[0]=1;for(int i=1;i<M;i++) pw2[i]=pw2[i-1]<<1;
lg2[0]=-1;for(int i=1;i<N;i++) lg2[i]=lg2[i>>1]+1;
for(int j=1;j<M;j++) for(int i=1;i<=m;i++) if(i+pw2[j]-1<=m){
int u=f[i][j-1],v=f[i+pw2[j-1]][j-1];
if(d[u]<d[v]) f[i][j]=u;else f[i][j]=v;
}
}
int LCA(int u,int v) {
if(pos[u]>pos[v]) swap(u,v);
u=pos[u],v=pos[v];
int lg=lg2[v-u+1];
return d[f[u][lg]]<d[f[v-pw2[lg]+1][lg]]?f[u][lg]:f[v-pw2[lg]+1][lg];
}
LL Dis(int u,int v) { return dis[u]+dis[v]-2*dis[LCA(u,v)]; }
void Add(int x) {
if(S.empty()) { S.insert(mpr(pos[x],x));return; }
set<pr>::iterator bf=S.lower_bound(mpr(pos[x],x)),bd=bf;
if(bf==S.begin()) bf=S.end();
bf--;
if(bd==S.end()) bd=S.begin();
ans-=Dis((*bf).y,(*bd).y);
ans+=Dis(x,(*bf).y),ans+=Dis(x,(*bd).y);
S.insert(mpr(pos[x],x));
}
void Del(int x) {
S.erase(mpr(pos[x],x));
if(S.empty()) return;
set<pr>::iterator bf=S.lower_bound(mpr(pos[x],x)),bd=bf;
if(bf==S.begin()) bf=S.end();
bf--;
if(bd==S.end()) bd=S.begin();
ans-=Dis((*bf).y,x),ans-=Dis((*bd).y,x);
ans+=Dis((*bf).y,(*bd).y);
S.erase(mpr(pos[x],x));
}
int main() {
n=in(),q=in();
for(int i=1,u,v,w;i<n;i++) u=in(),v=in(),w=in(),AddEdge(u,v,w);
init();
for(;q--;) {
int x=in();
if(mk[x]) Del(x),mk[x]=0;
else Add(x),mk[x]=1;
printf("%lld\n",ans);
}return 0;
}