Problem
给定一棵有n个节点、边带权的树。每次询问给出k条边,删掉这k条边后树被分成k+1个连通块,求这些连通块的直径之和。
Solution
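虚树典型套路。把树看成以1为根的有根树,每个连通块都是某个子树(顶点为被删边的下端点,或者根)挖掉其中更深的被删子树之后剩下的部分,因此它在dfs序上变成了若干段连续区间。删掉一条边只会让它最近的被删祖先(或根)所在的连通块多切出1段,所以所有连通块的区间总数是O(k)的。于是对dfs序建线段树,每个结点维护该区间内点集的直径(两个端点和长度):两个点集并起来的直径,端点一定出现在两条原直径的四个端点之中(边权非负时),所以合并时只需用LCA求O(1)对距离。建好虚树后自底向上逐块查询即可,写起来挺快,一次通过。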
Code
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<set>
#include<map>
#include<utility>
#include<vector>
// Inclusive loop helpers: fo counts i up from a to b, fd counts i down from a to b.
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
// 64-bit type used for weighted distances (dis) and the diameter sum (ans).
typedef long long LL;
// NOTE(review): db is not referenced in the visible code.
typedef double db;
// Reads the next (optionally negative) decimal integer from stdin.
// Any characters before the first digit or '-' are skipped.
// Returns 0 if end of input is reached before a number starts; the original
// stored getchar() in a char and never tested EOF, so it looped forever there.
int get(){
    int ch;
    do{
        ch=getchar();
        if (ch==EOF)return 0;
    }while((ch<'0'||ch>'9')&&ch!='-');
    bool neg=(ch=='-');
    int s=neg?0:ch-'0';
    // The terminating character (or EOF) is consumed, as before.
    while(ch=getchar(),ch>='0'&&ch<='9')s=s*10+ch-'0';
    return neg?-s:s;
}
// Capacity for vertices/edges; presumably chosen to exceed the problem's maximum n — confirm against the statement.
const int N = 100010;
// Adjacency-list entry: id = index of the undirected input edge (main numbers them 1..n-1),
// x = vertex this entry points to, w = edge weight, nxt = next entry leaving the same vertex.
struct edge{
int id,x,w,nxt;
}e[N*2];
// h[v]: first adjacency entry of v (0 = none); tot: number of entries used.
int h[N],tot;
// co[id]: the child (deeper) endpoint of input edge id, filled in by dfs.
int co[N];
// rmq: sparse table over the Euler tour; rmq[i][j] is the shallowest vertex in positions [i, i+2^j).
// fir[v]: first Euler position of v; u: Euler tour length; dep[v]: edge depth below vertex 1;
// d[i]: vertex whose preorder index is i.
int rmq[N*2][20],fir[N],u,dep[N],d[N];
// dfn[v]: preorder index of v; rig[v]: largest preorder index inside v's subtree,
// so v's subtree occupies the contiguous range [dfn[v], rig[v]]; k: preorder counter.
int dfn[N],rig[N],k;
// dis[v]: weighted distance from vertex 1 to v.
LL dis[N];
// bz[v]: visited flag used by dfs.
bool bz[N];
void inse(int x,int y,int z,int id){
e[++tot].x=y;
e[tot].id=id;
e[tot].w=z;
e[tot].nxt=h[x];
h[x]=tot;
}
// Depth-first traversal from x (called with the root, vertex 1).
// Fills preorder data (dfn, d, rig), the Euler tour used by the LCA sparse
// table (rmq[*][0], fir, u), depths and weighted distances (dep, dis), and
// co[id] = child endpoint of every tree edge.
// It uses an explicit stack of (vertex, next adjacency entry to try) instead
// of recursion, so a path-shaped tree with ~1e5 vertices cannot overflow the
// call stack. The visiting order and Euler tour match the recursive version.
void dfs(int x){
    std::vector<std::pair<int,int> > stk;
    bz[x]=1;
    d[dfn[x]=++k]=x;
    rmq[fir[x]=++u][0]=x;
    stk.push_back(std::make_pair(x,h[x]));
    while(!stk.empty()){
        int v=stk.back().first;
        int &p=stk.back().second;
        // Skip entries leading to already-visited vertices (the parent).
        while(p&&bz[e[p].x])p=e[p].nxt;
        if(!p){
            // v is finished: close its subtree range and, as the recursive
            // version did after each child returned, re-append its parent.
            rig[v]=k;
            stk.pop_back();
            if(!stk.empty())rmq[++u][0]=stk.back().first;
            continue;
        }
        int y=e[p].x;
        co[e[p].id]=y;
        dep[y]=dep[v]+1;
        dis[y]=dis[v]+e[p].w;
        // Advance before push_back: reallocation would invalidate p.
        p=e[p].nxt;
        bz[y]=1;
        d[dfn[y]=++k]=y;
        rmq[fir[y]=++u][0]=y;
        stk.push_back(std::make_pair(y,h[y]));
    }
}
void getrmq(){
fo(j,1,log(u)/log(2))
fo(i,1,u-(1<<j)+1)
if (dep[rmq[i][j-1]]<dep[rmq[i+(1<<(j-1))][j-1]])rmq[i][j]=rmq[i][j-1];
else rmq[i][j]=rmq[i+(1<<(j-1))][j-1];
}
int lca(int x,int y){
x=fir[x];y=fir[y];
if (x>y)swap(x,y);
int t=log(y-x+1)/log(2);
if (dep[rmq[x][t]]<dep[rmq[y-(1<<t)+1][t]])return rmq[x][t];
return rmq[y-(1<<t)+1][t];
}
// Weighted path length between x and y: both root distances minus twice the
// root distance of their LCA.
LL getdis(int x,int y){
    LL viaRoot=dis[x]+dis[y];
    return viaRoot-2*dis[lca(x,y)];
}
// Diameter summary of a vertex set: endpoints a[0], a[1] and their distance d.
// The merge operator and getv use d == -1 to mean "empty set".
struct lgs{
int a[2];
LL d;
};
// Diameter of the union of two vertex sets. The candidates are each set's own
// diameter and the four cross pairs of endpoints. If a is empty (d == -1),
// the result is b unchanged.
lgs operator + (lgs a,lgs b){
    if (a.d==-1)return b;
    lgs best=(a.d<b.d)?b:a;
    for(int i=0;i<2;i++){
        for(int j=0;j<2;j++){
            LL len=getdis(a.a[i],b.a[j]);
            if (len>best.d){
                best.d=len;
                best.a[0]=a.a[i];
                best.a[1]=b.a[j];
            }
        }
    }
    return best;
}
// Segment-tree node over preorder positions: x = diameter summary of the
// vertices in the node's range, l / r = indices of the child nodes.
struct point{
lgs x;
int l,r;
}tree[N*2];
// tt: number of segment-tree nodes allocated so far by build.
int tt;
// sta / tp: stack (and its top index) used by work() to build the virtual tree.
// NOTE(review): vis is not referenced anywhere in the visible code.
int vis[N],sta[N],tp;
// n: number of vertices; q: number of queries.
int n,q;
// Builds the segment tree over preorder positions [l, r]; the new node's
// index is written into node. A leaf holds the single vertex d[l] as both
// endpoints. Its length field stays 0 because tree[] is a zero-initialized
// global.
void build(int &node,int l,int r){
    node=++tt;
    if (l<r){
        int mid=(l+r)/2;
        build(tree[node].l,l,mid);
        build(tree[node].r,mid+1,r);
        tree[node].x=tree[tree[node].l].x+tree[tree[node].r].x;
        return;
    }
    tree[node].x.a[0]=d[l];
    tree[node].x.a[1]=d[l];
}
// a[1..nt]: key vertices of the current query; w[1..m]: every vertex placed
// on the virtual tree during the query, so its state can be reset afterwards.
int a[N],w[N],m;
// is[v]: v is the top vertex of a component (child endpoint of a deleted edge, or vertex 1).
bool is[N];
// Virtual-tree children as a linked list: L[v] = first child of v, R[c] = next sibling of c.
int L[N],R[N];
// Sum of component diameters for the current query.
LL ans;
// A preorder range [l, r]; s is a stack of finished component subtrees used by getans.
struct section{
int l,r;
}s[N];
// t: top index of stack s (0 = empty).
int t;
// Sort key: orders vertices by their preorder (dfs) index.
bool cmp(int x,int y){
    return dfn[y]>dfn[x];
}
// Makes y a virtual-tree child of x by prepending it to x's child list.
void add(int x,int y){
    int oldHead=L[x];
    R[y]=oldHead;
    L[x]=y;
}
// Diameter summary of the vertices at preorder positions [x, y] ∩ [l, r],
// where now is the segment-tree node covering [l, r].
// Callers pass x <= y inside [1, n], so at least one branch contributes.
lgs getv(int now,int l,int r,int x,int y){
    if (x<=l&&r<=y)return tree[now].x;
    int mid=(l+r)/2;
    // Empty accumulator (d == -1). The endpoints are set too, so the copy
    // into operator+ never reads indeterminate values. The name res avoids
    // shadowing the global ans.
    lgs res;
    res.a[0]=res.a[1]=-1;
    res.d=-1;
    if (x<=mid)res=res+getv(tree[now].l,l,mid,x,y);
    if (y>mid)res=res+getv(tree[now].r,mid+1,r,x,y);
    return res;
}
// Post-order walk of the virtual tree below x. For every vertex marked in is[],
// it adds the diameter of that vertex's component to the global ans.
// A marked x owns the preorder range [dfn[x], rig[x]] minus the ranges of
// marked descendants. Those ranges were pushed on stack s while x's
// virtual-tree children were processed.
// Recursion depth equals the depth of the virtual tree.
void getans(int x){
for(int y=L[x];y;y=R[y])getans(y);
if (is[x]){
// NOTE(review): this local u shadows the global Euler-tour counter u.
section u;
lgs tmp;
// d = -1: start from the empty set.
tmp.a[0]=tmp.a[1]=tmp.d=-1;
u.l=dfn[x];u.r=rig[x];
// Pop each finished section nested inside x's subtree and query the gap in
// front of it. This sweeps left to right, so it assumes the sections come
// off the stack in increasing dfn order. NOTE(review): that ordering depends
// on the child order produced by add() in work(). s[0] is never written
// (pushes use ++t from t = 0), so its l = 0 < dfn[x] stops the loop.
while(s[t].l>=dfn[x]&&s[t].r<=rig[x]){
if (u.l<s[t].l)tmp=tmp+getv(1,1,n,u.l,s[t].l-1);
u.l=s[t].r+1;
t--;
}
// Trailing gap after the last nested section.
if (u.l<=u.r)tmp=tmp+getv(1,1,n,u.l,u.r);
ans=ans+tmp.d;
// x's whole subtree is now finished; expose it to x's marked ancestor.
s[++t].l=dfn[x];
s[t].r=rig[x];
}
}
// Answers q queries. Each query lists nt input-edge ids. The child endpoint
// of each edge (co[]) and vertex 1 become component tops. A virtual tree over
// them is built, getans sums the component diameters, the sum is printed, and
// the per-query state is reset.
void work(){
q=get();
fo(cas,1,q){
int nt=get();
// Map edge ids to their child endpoints and mark them as component tops.
fo(i,1,nt)is[a[i]=co[get()]]=1;
sort(a+1,a+1+nt,cmp);
is[1]=1;
// Virtual-tree construction: sta holds the current root path; w records every
// vertex touched so it can be cleared at the end of the query.
sta[tp=1]=w[m=1]=1;
fo(i,1,nt){
int x=a[i],y=lca(x,sta[tp]);
// Pop stack vertices deeper than y, linking each to its virtual-tree parent:
// the vertex below it on the stack if that one is still deeper than y, otherwise y.
// NOTE(review): getans relies on the order in which these add() calls link
// siblings (add prepends), so do not reorder this construction.
while(dep[sta[tp]]>dep[y]){
if (tp>1&&dep[sta[tp-1]]>dep[y])add(sta[tp-1],sta[tp]);
else add(y,sta[tp]);
tp--;
}
if (sta[tp]!=y){sta[++tp]=y;w[++m]=y;}
if (sta[tp]!=x){sta[++tp]=x;w[++m]=x;}
// Redundant with the marking done while reading the query.
is[x]=1;
}
// Link the vertices still on the stack as a chain.
fo(i,1,tp-1)add(sta[i],sta[i+1]);
ans=0;
t=0;
getans(1);
printf("%lld\n",ans);
// Reset per-query state for every vertex that was placed on the virtual tree.
fo(i,1,m){
int x=w[i];
L[x]=R[x]=is[x]=0;
}
}
}
int main(){
freopen("data.in","r",stdin);
freopen("data.out","w",stdout);
n=get();
fo(i,1,n-1){
int x=get(),y=get(),z=get();
inse(x,y,z,i);
inse(y,x,z,i);
}
dfs(1);
getrmq();
int rt;
build(rt,1,n);
work();
fclose(stdin);
fclose(stdout);
return 0;
}