题目:http://acm.hdu.edu.cn/showproblem.php?pid=5296
题意:给定N个点和N-1条边以及边的权值,再给定一个set,向set里面添加/删除节点求让set里面的节点连通的最小费用。
分析:看大神博客知道的。这题主要求点到链的距离。先解决点到点的距离,假如dis[x]表示x到root的距离,那么点u和v的距离就是dis[u]+dis[v]-dis[lca(u,v)]。求u到链x~y的距离:令u到点x的距离为d1,u到点y的距离为d2,x到y的距离为d3,那么u到链x~y的距离就是(d1+d2-d3)/2。怎么找怎么链x~y?用一个set维护即可。
代码:
#include <iostream>
#include <cstdio>
#include <vector>
#include <set>
using namespace std;
const int maxn = 1e6+6;
struct node
{
int w,o;
node(){}
node(int a,int b)
:o(a),w(b){}
};
int N,M,p[maxn],fp[maxn],anc[maxn][20],deep[maxn],dis[maxn],cur,visit[maxn];
vector <node > vct[maxn];
set <int > st;
set <int >::iterator it,it1,it2;
void Init()
{
int i,j;
st.clear();
for(i=0;i<=N;i++)
{
vct[i].clear();
visit[i]=0;
}
for(i=1;i<=N;i++)
for(j=0;j<=17;j++)
anc[i][j]=1;
}
void Create(int root,int d,int dist)
{
deep[root]=d;
dis[root]=dist;
for(int i=0;i<vct[root].size();i++)
{
int son=vct[root][i].o;
if(anc[root][0]!=son)
{
anc[son][0]=root;
for(int j=1;j<=17;j++)
anc[son][j]=anc[anc[son][j-1]][j-1];
Create(son,d+1,dist+vct[root][i].w);
}
}
}
int Jump(int a,int x)
{
for(int i=0;i<=17;i++)
if(x&(1<<i))
a=anc[a][i];
return a;
}
int LCA(int u,int v)
{
if(deep[u]<deep[v])
swap(u,v);
u=Jump(u,deep[u]-deep[v]);
if(u==v)
return u;
for(int i=17;i>=0;i--)
if(anc[u][i]!=anc[v][i])
{
u=anc[u][i];
v=anc[v][i];
}
return anc[u][0];
}
void toNum(int root)
{
p[root]=++cur;
fp[cur]=root;
for(int i=0;i<vct[root].size();i++)
{
int son=vct[root][i].o;
if(anc[root][0]!=son)
toNum(son);
}
}
int main()
{
int ncase,i,j,u,v,w,tp,x,y,fx,fy,ans;
scanf("%d",&ncase);
for(int T=1;T<=ncase;T++)
{
scanf("%d%d",&N,&M);
Init();
for(i=1;i<N;i++)
{
scanf("%d%d%d",&u,&v,&w);
vct[u].push_back(node(v,w));
vct[v].push_back(node(u,w));
}
cur=0;
Create(1,0,0);
toNum(1);
ans=0;
printf("Case #%d:\n",T);
while(M--)
{
scanf("%d%d",&tp,&u);
if(tp==1)
{
if(!visit[u])
{
st.insert(p[u]);
visit[u]=1;
if(st.find(p[u])!=st.begin())
it1=--st.find(p[u]);
else
it1=st.end();
if(st.find(p[u])!=st.end())
it2=++st.find(p[u]);
else
it2=st.end();
if(it1==st.end() || it2==st.end())
{
st.erase(st.find(p[u]));
if(st.size()<1)
puts("0");
else
{
x=*st.begin();
y=*(--st.end());
fx=fp[x];
fy=fp[y];
ans+=dis[u]-dis[LCA(u,fx)]-dis[LCA(u,fy)]+dis[LCA(fx,fy)];
printf("%d\n",ans);
}
st.insert(p[u]);
}
else
{
x=*it1;
y=*it2;
fx=fp[x];
fy=fp[y];
ans+=dis[u]-dis[LCA(u,fx)]-dis[LCA(u,fy)]+dis[LCA(fx,fy)];
printf("%d\n",ans);
}
}
else
printf("%d\n",ans);
}
else
{
if(visit[u])
{
if(st.find(p[u])!=st.begin())
it1=--st.find(p[u]);
else
it1=st.end();
if(st.find(p[u])!=st.end())
it2=++st.find(p[u]);
else
it2=st.end();
if(it1==st.end() || it2==st.end())
{
st.erase(st.find(p[u]));
if(st.size()<1)
puts("0");
else
{
x=*st.begin();
y=*(--st.end());
fx=fp[x];
fy=fp[y];
ans-=dis[u]-dis[LCA(u,fx)]-dis[LCA(u,fy)]+dis[LCA(fx,fy)];
printf("%d\n",ans);
}
st.insert(p[u]);
}
else
{
x=*it1;
y=*it2;
fx=fp[x];
fy=fp[y];
ans-=dis[u]-dis[LCA(u,fx)]-dis[LCA(u,fy)]+dis[LCA(fx,fy)];
printf("%d\n",ans);
}
st.erase(st.find(p[u]));
}
else
printf("%d\n",ans);
visit[u]=0;
}
}
}
return 0;
}