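/*
Problem: given a weighted undirected tree with n nodes and a set S that is initially empty,
process two kinds of operations:
  1 x  -- if x is not in S, insert x into S;
  2 x  -- if x is in S, remove x from S.
After every operation, output the minimum total edge weight needed to connect all nodes of S.

Approach: let root be the LCA of all nodes in S. L[i] and R[i] are the entry/exit DFS
timestamps of node i, dis[i] is the weighted distance from the tree root to i, and ans is
the current total weight.
Inserting x into S: if x is not in the subtree of root, then
  ans += dis[x] + dis[root] - 2*dis[lca(x, root)], and root becomes lca(x, root).
If x is in the subtree of root, walk up from x to the nearest node u whose subtree already
contains some node of S; then ans += dis[x] - dis[u].
Removing x from S: let S' be the set after the removal. If lca(S') != root, then
  ans -= dis[x] + dis[lca(S')] - 2*dis[root], and root becomes lca(S').
If lca(S') == root, walk up from x to the nearest node u whose subtree contains some node
of S'; then ans -= dis[x] - dis[u].
Code:
*/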
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<functional>
#include<queue>
#include<set>
// Lowest set bit of i (i & -i), used by the Fenwick tree.
// A function instead of a macro, so that a compound argument such as
// lowbit(a + b) is evaluated as one value (the old macro expanded to
// (a+b&-a+b) and gave wrong results).
inline int lowbit(int i){ return i & -i; }
using namespace std;
typedef std::pair<int, int> P;   // not referenced in the visible code
const int maxn = 100000 + 5;     // capacity of the node-indexed arrays (presumably n <= 100000)
// Reads the next (optionally negative) decimal integer from stdin into ret.
// Any characters other than digits and '-' before the number are skipped.
// Returns 1 on success, or 0 if end of input is reached before a number
// starts (ret is left unchanged in that case).
inline bool read(int &ret){
    int c = getchar();   // int, not char: it must be able to hold EOF
    // Stop at EOF too; otherwise trailing whitespace would loop forever.
    while(c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if(c == EOF) return 0;
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while(c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return 1;
}
// One directed entry of the adjacency lists stored in buf:
//   b    - node this edge leads to
//   val  - edge weight
//   next - index in buf of the next edge leaving the same node (-1 ends the list)
struct Nod{
    int b, val, next;
    void init(int to, int weight, int nxt){
        b = to;
        val = weight;
        next = nxt;
    }
} buf[maxn<<1];   // two directed entries per undirected edge
int n;                 // number of nodes in the current test case
int m;                 // number of operations in the current test case
int len;               // number of entries used in buf
int E[maxn];           // head edge index in buf for each node (-1 = no edges)
int ans;               // current minimum connecting weight of S
int L[maxn];           // DFS entry timestamp of each node
int R[maxn];           // DFS exit timestamp of each node
int tot;               // timestamp counter
int root;              // LCA of the nodes in S; -1 when no root is set
int par[maxn][20];     // par[v][k]: 2^k-th ancestor of v, -1 past the tree root
int dis[maxn];         // weighted distance from node 0
int dep[maxn];         // depth (edge count) from node 0
int bit[maxn<<1];      // Fenwick tree over timestamps; position L[v] counts v while v is in S
set<int> ms;           // the set S itself
// Fenwick-tree point update: adds x at position k (positions 1..tot).
void add(int k,int x){
    while(k <= tot){
        bit[k] += x;
        k += k & -k;
    }
}
// Fenwick-tree prefix query: total of positions 1..k (0 for k == 0).
int sum(int k){
    int total = 0;
    while(k != 0){
        total += bit[k];
        k -= k & -k;
    }
    return total;
}
// Resets per-test-case state. It reads n, so it must run after n is known:
// clears the set S, the counters, the edge list heads E[0..n-1] and the
// Fenwick tree bit[0..2n].
void init(){
    len = 0;
    tot = 0;
    ans = 0;
    root = -1;
    ms.clear();
    memset(E, -1, sizeof(int) * n);
    memset(bit, 0, sizeof(int) * (2 * n + 1));
}
// Adds the undirected edge a-b with weight c. Each direction becomes its own
// entry in buf and is pushed onto the front of its endpoint's list.
void add_edge(int a,int b,int c){
    buf[len].init(b, c, E[a]);
    E[a] = len;
    ++len;
    buf[len].init(a, c, E[b]);
    E[b] = len;
    ++len;
}
void dfs(int u,int pre,int deep){
int i,v;
par[u][0] = pre;dep[u] = deep;
L[u] = ++tot;
for(i = E[u];i != -1;i = buf[i].next){
v = buf[i].b;
if(v == pre)continue;
dis[v] = dis[u] + buf[i].val;
dfs(v,u,deep+1);
}
R[u] = ++tot;
}
// Runs the DFS from node 0 (timestamps, depths, distances, direct parents).
// Then it fills the binary-lifting table: par[v][k] is the 2^k-th ancestor
// of v, or -1 past the root.
void preprocess(){
    dfs(0, -1, 0);
    for(int k = 1; k < 20; ++k){
        for(int v = 0; v < n; ++v){
            const int half = par[v][k - 1];
            par[v][k] = (half == -1) ? -1 : par[half][k - 1];
        }
    }
}
// Returns the ancestor of u that is `dist` edges above it (u itself for 0),
// or -1 if that would go past the tree root. It jumps through par using the
// bits of dist from bit 19 down (dist is assumed to be below 2^20).
int GetVet(int u,int dist){
    int bitIdx = 19;
    while(bitIdx >= 0){
        const int step = 1 << bitIdx;
        if(dist >= step){
            dist -= step;
            u = par[u][bitIdx];
            if(u == -1) return -1;
        }
        --bitIdx;
    }
    return u;
}
void AddOp(int u){
int i,j,v;
if(root==-1){
root = u;
ans = 0;
ms.insert(u);
add(L[u],1);
return;
}
int lb,rb,mid,vv;
lb = -1,rb = n+1;
while(rb-lb>1){
mid = (lb+rb)>>1;
vv = GetVet(u,mid);
if(vv==-1||sum(R[vv])-sum(L[vv]-1)>0){
rb = mid;
v = vv;
}else lb = mid;
}
if(L[u]>=L[root]&&L[u]<=R[root]){
ans += dis[u]-dis[v];
}else{
ans += dis[root]+dis[u]-2*dis[v];
root = v;
}
add(L[u],1);
ms.insert(u);
}
// Returns the ancestor of x closest to x (x itself at distance 0) whose
// subtree holds at least `need` nodes of S according to the Fenwick tree.
// It binary-searches the distance 0..n; the property is monotone going up.
// Returns -1 only when no ancestor up to the tree root qualifies.
static int DelOpAncestorHolding(int x, int need){
    int lb = -1;
    int rb = n + 1;
    int found = -1;
    while(rb - lb > 1){
        const int mid = (lb + rb) >> 1;
        const int vv = GetVet(x, mid);
        if(vv == -1 || sum(R[vv]) - sum(L[vv] - 1) >= need){
            rb = mid;
            found = vv;
        }else{
            lb = mid;
        }
    }
    return found;
}

// Removes u from S (the caller checks that u is in S). Updates ans, the
// weight of the minimal subtree connecting S, and root, the LCA of S.
void DelOp(int u){
    add(L[u], -1);
    ms.erase(u);
    const int remaining = (int)ms.size();
    if(remaining <= 1){
        // Zero or one node needs no edges. With one node left, root must stay
        // valid (it is that node), because AddOp treats root == -1 as an
        // empty S and would otherwise drop the remaining node from the answer.
        root = (remaining == 0) ? -1 : *ms.begin();
        ans = 0;
        return;
    }
    // lca(S'): the lowest ancestor of any remaining node whose subtree holds
    // all of them. Subtree counts never exceed `remaining`.
    const int newRoot = DelOpAncestorHolding(*ms.begin(), remaining);
    if(newRoot == root){
        // LCA unchanged: remove only the branch from u up to the nearest
        // ancestor whose subtree still contains a node of S'.
        const int attach = DelOpAncestorHolding(u, 1);
        ans -= dis[u] - dis[attach];
    }else{
        // u was the only node on its side of root: remove u..root and root..newRoot.
        ans -= dis[newRoot] + dis[u] - 2 * dis[root];
        root = newRoot;
    }
}
// For each test case: reads the tree (input nodes are 1-based and converted
// to 0-based) and prepares it. Then it applies every operation and prints
// the current answer after each one.
int main(){
    int cas, u, v, w, op;
    scanf("%d", &cas);
    for(int T = 1; T <= cas; ++T){
        scanf("%d%d", &n, &m);
        init();
        for(int edge = 1; edge < n; ++edge){
            read(u);
            read(v);
            read(w);
            --u;
            --v;
            add_edge(u, v, w);
        }
        preprocess();
        printf("Case #%d:\n", T);
        while(m--){
            scanf("%d%d", &op, &u);
            --u;
            if(op == 1){
                if(ms.find(u) == ms.end()) AddOp(u);
            }else if(op == 2){
                if(ms.find(u) != ms.end()) DelOp(u);
            }
            printf("%d\n", ans);
        }
    }
    return 0;
}