Problem
题目概要:
给定一棵 n n 个节点的树,次询问:给定 m m 个关键点,每个原树上的点被最近且序号最小的关键点控制,问每个关键点()
Solution
观察数据限制:
∑m≤300000
∑
m
≤
300000
类似于这种总和限制的题目应该要想到虚树
先建出虚树,考虑如何统计答案
很容易想到一条边上的点只有可能被最靠近边两端点的关键点控制,所以应该求出虚树上每个点最接近的关键点
对于每一条边两端点为
x,y
x
,
y
,控制这两个点的标号为
tx,ty
t
x
,
t
y
:
如果
tx=ty
t
x
=
t
y
,则这条边上的贡献应全部属于
tx
t
x
如果 tx≠ty t x ≠ t y ,则肯定存在一个点 inv i n v 将这条边划分开来,其中 x x ~, inv i n v ~ y y 的贡献分别属于,这里可以用类似于lca的倍增求解
相应的,虚树上有没有体现的点,比如下图中黑色点为关键点,则绿色部分并不会在虚树中体现
所以这些点就必须将贡献交给最近的关键节点,设 rest[x] r e s t [ x ] 表示在虚树中未出现归到这个点的节点数量,则 rest[x] r e s t [ x ] 应等于 x x 在原树上的减去在虚树上的 sz s z
求lca还是树剖快,但在求切分点时应用lca方法求解,所以我打了树剖求lca加倍增求切分,我们推荐用 RMQ R M Q 来做
这题搞清楚细节就不是很难了
Code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define rg register
template <typename _Tp> inline _Tp read(_Tp&x){
char c11=getchar(),ob=0;x=0;
while(c11^'-'&&!isdigit(c11))c11=getchar();if(c11=='-')c11=getchar(),ob=1;
while(isdigit(c11))x=x*10+c11-'0',c11=getchar();if(ob)x=-x;return x;
}
const int N=301000,M=21;
struct EDGE{int v,nxt;}a[N<<1];
int head[N],anc[N][M],sz[N],son[N],ltop[N],fa[N],depth[N],dfn[N],num[N],f[N],sta[N],vis[N];
int b[N],c[N],rest[N];
int n,m,Q,_,dfc,tot,top;
inline int cmp(const int&AA,const int&BB){return dfn[AA]<dfn[BB];}
inline void add(int u,int v){a[++_].v=v,a[_].nxt=head[u],head[u]=_;}
inline void dfs1(int x,int dad){
dfn[x]=++dfc;anc[x][0]=fa[x]=dad;
for(rg int i=1;i<20;++i)anc[x][i]=anc[anc[x][i-1]][i-1];
depth[x]=depth[dad]+1;
sz[x]=1;
int mxsz(0);
for(int i=head[x];i;i=a[i].nxt)
if(a[i].v!=dad){
dfs1(a[i].v,x);
sz[x]+=sz[a[i].v];
if(sz[a[i].v]>mxsz)son[x]=a[i].v,mxsz=sz[a[i].v];
}
return ;
}
inline void dfs2(int x,int Top){
ltop[x]=Top;
if(son[x])dfs2(son[x],Top);
for(int i=head[x];i;i=a[i].nxt)
if(a[i].v!=fa[x]&&a[i].v!=son[x])
dfs2(a[i].v,a[i].v);
return ;
}
inline int lca(int x,int y){
while(ltop[x]!=ltop[y]){
if(depth[ltop[y]]>depth[ltop[x]])swap(x,y);
x=fa[ltop[x]];
}
if(depth[x]>depth[y])swap(x,y);
return x;
}
inline int dis(int x,int y){return depth[x]+depth[y]-depth[lca(x,y)]*2;}
inline void dfs3(int x){
vis[++tot]=x,rest[x]=sz[x];
for(int i=head[x];i;i=a[i].nxt){
dfs3(a[i].v);
if(num[a[i].v]){
int t1=dis(x,num[x]),t2=dis(x,num[a[i].v]);
if(!num[x]||t1>t2||(t1==t2&&num[x]>num[a[i].v]))num[x]=num[a[i].v];
}
}return ;
}
inline void dfs4(int x){
for(int i=head[x];i;i=a[i].nxt){
int t1=dis(num[a[i].v],a[i].v),t2=dis(num[x],a[i].v);
if(!num[a[i].v]||t1>t2||(t1==t2&&num[x]<num[a[i].v]))num[a[i].v]=num[x];
dfs4(a[i].v);
}return ;
}
inline void count(int x,int y){
int ty=y;
while(ltop[x]!=ltop[ty])
if(fa[ltop[ty]]==x){ty=ltop[ty];break;}
else ty=fa[ltop[ty]];
if(ltop[x]==ltop[ty])ty=son[x];
rest[x]-=sz[ty];
if(num[x]==num[y]){f[num[x]]+=sz[ty]-sz[y];return ;}
int inv=y,t1,t2;
for(rg int i=19;~i;--i)
if(depth[anc[inv][i]]>depth[x]){
t1=dis(anc[inv][i],x);
t2=dis(anc[inv][i],y);
if(t1>t2||(t1==t2&&num[x]>num[y]))inv=anc[inv][i];
}
f[num[x]]+=sz[ty]-sz[inv];
f[num[y]]+=sz[inv]-sz[y];
return ;
}
int main(){
freopen("in","r",stdin);
read(n);
for(rg int i=1,x,y;i<n;++i)read(x),read(y),add(x,y),add(y,x);
dfs1(1,0);dfs2(1,1);_=0;
for(rg int i=1;i<=n;++i)head[i]=0;
read(Q);
while(Q--){
read(m);top=_=tot=0;
for(rg int i=1;i<=m;++i){c[i]=read(b[i]);num[c[i]]=c[i];}
sort(b+1,b+m+1,cmp);
if(num[1]!=1)sta[top=1]=1;
for(rg int i=1;i<=m;++i){
int x=b[i],o=0;
while(top){
o=lca(x,sta[top]);
if(top>1&&depth[o]<depth[sta[top-1]]){add(sta[top-1],sta[top]);--top;}
else if(depth[o]<depth[sta[top]]){add(o,sta[top]);--top;break;}
else break;
}
if(sta[top]!=o)sta[++top]=o;
sta[++top]=x;
}
while(top>1){add(sta[top-1],sta[top]);--top;}
dfs3(1);dfs4(1);
for(rg int x=1;x<=tot;++x)
for(rg int i=head[vis[x]];i;i=a[i].nxt)
count(vis[x],a[i].v);
for(rg int i=1;i<=tot;++i)f[num[vis[i]]]+=rest[vis[i]];
for(rg int i=1;i<m;++i)
printf("%d ",f[c[i]]);
printf("%d\n",f[c[m]]);
for(rg int i=1;i<=tot;++i)
head[vis[i]]=f[vis[i]]=num[vis[i]]=0;
}return 0;
}