Description
小Y有一棵n个节点的树,每条边都有正的边权。
小J有q个询问,每次小J会删掉这个树中的k条边,这棵树被分成k+1个连通块。小J想知道每个连通块中最远点对距离的和。
这里的询问是互相独立的,即每次都是在小Y的原树上进行操作。
n,q,∑k<=100000,边权<=109 n , q , ∑ k <= 100000 , 边 权 <= 10 9
Solution
删掉一条边,相当于将一个子树割裂出来,体现在DFS序上就是删掉了一段区间。
这样一个联通块就由DFS序上若干个区间组成。
我们将这些删边操作排序放到DFS序上,跑一次,可以求出每个联通块由哪些区间组成,线段树维护DFS序区间直径即可(记录直径端点,合并的时候四个点组合看哪个最大)。
求点对距离的话,RMQ做LCA,每次询问是O(1)的
由于每次删边最多只会将一个区间分成三个,因此总查询区间数仍是O(K)的
那么总复杂度就是 O(nlogn+∑K(logn+logK)) O ( n log n + ∑ K ( log n + log K ) )
Code
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fod(i,a,b) for(int i=a;i>=b;i--)
#define N 200005
#define M 400005
#define LL long long
using namespace std;
int pt[M],dep[N],sz[N],dfn[N],n1,n,m1,q,ask[M],rmq[M][18],cf[18],dft[N],wt[M],fs[N],nt[M],dt[M],pr[M],l2[N],dfw[N],t[M][2],fx[M],st[M];
LL dis[M];
int m2,nx[N],fi[N];
struct node
{
int x,y;
node(int _x=0,int _y=0):x(_x),y(_y){};
friend bool operator <(node x,node y)
{
return (x.x<y.x||(x.x==y.x&&x.y<y.y));
}
}dx[M],d1[M];
void link(int x,int y,int w)
{
nt[++m1]=fs[x];
dt[fs[x]=m1]=y;
pr[m1]=w;
}
void lk(int x,int y)
{
nx[y]=fi[x];
fi[x]=y;
}
int dmin(int x,int y)
{
return ((dep[x]<dep[y])?x:y);
}
void dfs(int k,int fa)
{
dfw[dfn[k]=++dfn[0]]=k;
dft[wt[k]=++dft[0]]=k;
dep[k]=dep[fa]+1;
sz[k]=1;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa) pt[i]=pt[fx[i]]=p,dis[p]=dis[k]+pr[i],dfs(p,k),sz[k]+=sz[p],dft[++dft[0]]=k;
}
}
int lca(int x,int y)
{
int l=wt[x],r=wt[y];
if(l>r) swap(l,r);
int p=l2[r-l+1];
return dmin(rmq[l][p],rmq[r-cf[p]+1][p]);
}
LL ds(node &a)
{
return dis[a.x]+dis[a.y]-(LL)2*dis[lca(a.x,a.y)];
}
node dmax(node a,node b)
{
return (ds(a)>ds(b))?(a):(b);
}
node merge(node a,node b)
{
return dmax(dmax(a,b),dmax(dmax(node(a.x,b.x),node(a.x,b.y)),dmax(node(a.y,b.x),node(a.y,b.y))));
}
void build(int k,int l,int r)
{
if(l==r) dx[k].x=dx[k].y=dfw[l];
else
{
int mid=(l+r)>>1;
t[k][0]=++n1,build(n1,l,mid);
t[k][1]=++n1,build(n1,mid+1,r);
dx[k]=merge(dx[t[k][0]],dx[t[k][1]]);
}
}
node query(int k,int l,int r,int x,int y)
{
if(l==x&&r==y) return dx[k];
int mid=(l+r)>>1;
if(y<=mid) return query(t[k][0],l,mid,x,y);
else if(x>mid) return query(t[k][1],mid+1,r,x,y);
else return merge(query(t[k][0],l,mid,x,mid),query(t[k][1],mid+1,r,mid+1,y));
}
LL get(int l,int r)
{
LL s=0;
int ls=r;
node p=node(dfw[l],dfw[l]);
for(int i=fi[l];i;i=nx[i])
{
s=s+get(i,i+sz[dfw[i]]-1);
if(ls>=i+sz[dfw[i]]) p=merge(p,query(1,1,n,i+sz[dfw[i]],ls));
ls=i-1;
}
if(ls>l) p=merge(p,query(1,1,n,l,ls));
return(ds(p)+s);
}
int main()
{
cin>>n;
cf[0]=1;
fo(i,1,17) cf[i]=cf[i-1]*2,l2[cf[i]]=i;
fo(i,2,2*n) if(!l2[i]) l2[i]=l2[i-1];
fo(i,1,n-1)
{
int x,y,w;
scanf("%d%d%d",&x,&y,&w);
link(x,y,w),link(y,x,w);
fx[m1]=m1-1,fx[m1-1]=m1;
}
dfs(1,0);
fo(i,1,dft[0]) rmq[i][0]=dft[i];
fo(j,1,17)
fo(i,1,dft[0])
rmq[i][j]=dmin(rmq[i][j-1],rmq[i+cf[j-1]][j-1]);
n1=1;
build(1,1,n);
cin>>q;
fo(t,1,q)
{
int c,c1=0;
scanf("%d",&c);
fo(i,1,c)
{
scanf("%d",&ask[i]),ask[i]=pt[ask[i]*2];
d1[++c1].x=dfn[ask[i]],d1[c1].y=ask[i];
d1[++c1].x=dfn[ask[i]]+sz[ask[i]],d1[c1].y=-1;
}
d1[++c1].x=d1[c1].y=1,d1[++c1].x=n+1,d1[c1].y=-1;
sort(d1+1,d1+c1+1);
int top=0;
fo(i,1,c1)
{
if(d1[i].y<0) st[top--]=0;
else
{
if(top) lk(st[top],d1[i].x);
st[++top]=d1[i].x;
}
}
printf("%lld\n",get(1,n));
fi[1]=0;
fo(i,1,c) fi[dfn[ask[i]]]=0;
}
}