题目大意
有一棵 n n n个点的树,每个点有一个编号,有 q q q次操作。对于每次操作,给出 m m m个点并称为议事点,树上各个点由离这个点最近的议事点管理(如果有多个议事点离这个点最近,则这个点由这些议事点中编号最小的点管理)。求每次操作中每个议事点管理的点的数目。
1 ≤ n ≤ 3 × 1 0 5 , 1 ≤ q ≤ 3 × 1 0 5 , 1 ≤ ∑ m i ≤ 3 × 1 0 5 1\leq n\leq 3\times 10^5,1\leq q\leq 3\times 10^5,1\leq \sum m_i\leq 3\times 10^5 1≤n≤3×105,1≤q≤3×105,1≤∑mi≤3×105
题解
前置知识:虚树
如果直接 d f s dfs dfs的话,则要通过两遍 d f s dfs dfs。对于一个非议事点 u u u,先用 u u u的子树中离 u u u最近的议事点来更新 u u u,再用 u u u的子树之外的离 u u u最近的议事点来更新 u u u。
下面是暴力的代码。
code
void dfs1(int u,int fa){
dis[u]=inf;
for(int i=r[u];i;i=l[i]){
if(d[i]==fa) continue;
dfs(d[i],u);
if(dis[d[i]]+1<dis[u]){
dis[u]=dis[d[i]]+1;
to[u]=to[d[i]];
}
else if(dis[d[i]]+1==dis[u]){
to[u]=min(to[u],to[d[i]]);
}
}
if(z[u]){
dis[u]=0;to[u]=u;
}
}
void dfs2(int u,int fa){
for(int i=r[u];i;i=l[i]){
if(d[i]==fa) continue;
if(dis[u]+1<dis[d[i]]){
dis[d[i]]=dis[u]+1;
to[d[i]]=to[u];
}
else if(dis[u]+1==dis[d[i]]){
to[d[i]]=min(to[d[i]],to[u]);
}
dfs2(d[i],u);
}
z[u]=0;
}
因为 1 ≤ ∑ m i ≤ 3 × 1 0 5 1\leq \sum m_i\leq 3\times 10^5 1≤∑mi≤3×105,所以我们可以用虚树来解决问题。
对于每次操作,建一棵虚树。对于虚树上的点,用上面的 d f s dfs dfs跑一Bianc遍即可。
那不在虚树上的点怎么判断呢?
我们要分两种情况:
- 如果原树上有一个点 u u u,它有一个儿子 v v v,且 v v v的子树中没有议事点,则子树 v v v中的每个点肯定都是由离 u u u最近的议事点管理
- 如果虚树上有一个点 u u u,它有一个儿子 v v v,且 u u u和 v v v之间含有若干个点,则用倍增求 u u u和 v v v之间的这条链上的分界点,分界点及以下的部分由离 v v v最近的议事点管理,分界点以上的部分由离 u u u最近的议事点管理,注意链上附带的点也根据在分界点的上面或下面来判断由哪个议事点管理
对于第一种情况,在 d f s dfs dfs时对每个议事点的答案加上其子树的大小即可。
对于第二种情况,将分界点 x x x及以下的点的数量(即 s i z [ x ] − s i z [ v ] siz[x]-siz[v] siz[x]−siz[v])加在离 v v v最近的议事点上,将分界点以上的点的数量(因为在第一种情况中已经加上了每个议事点的子树大小,所以只需要减去 s i z [ x ] siz[x] siz[x])加在离 u u u最近的议事点上。
时间复杂度为 O ( n log n + ∑ m log m ) O(n\log n+\sum m\log m) O(nlogn+∑mlogm)。
code
#include<bits/stdc++.h>
using namespace std;
const int N=300005,inf=1e9;
int n,q,k,tot=0,wt=0,top=0,d[N*2],l[N*2],r[N*2],a[N],b[N];
int dep[N],dfn[N],siz[N],s[N],z[N],dis[N],to[N],ans[N],f[N][20];
vector<int>v[N];
int cmp(int ax,int bx){
return dfn[ax]<dfn[bx];
}
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void pt(int u,int fa){
dfn[u]=++wt;
dep[u]=dep[fa]+1;
siz[u]=1;
f[u][0]=fa;
for(int i=1;i<=19;i++){
f[u][i]=f[f[u][i-1]][i-1];
}
for(int i=r[u];i;i=l[i]){
if(d[i]==fa) continue;
pt(d[i],u);
siz[u]+=siz[d[i]];
}
}
int gt(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=19;i>=0;i--){
if(dep[f[x][i]]>=dep[y]) x=f[x][i];
}
if(x==y) return x;
for(int i=19;i>=0;i--){
if(f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
void insert(int x){
if(top==1){
s[++top]=x;
return;
}
int lca=gt(x,s[top]);
while(top>1&&dfn[s[top-1]]>=dfn[lca]){
v[s[top-1]].push_back(s[top]);--top;
}
if(s[top]!=lca){
v[lca].push_back(s[top]);
s[top]=lca;
}
s[++top]=x;
}
void dd(int x,int y){
int u=y;
for(int i=19;i>=0;i--){
int v1,v2;
v1=dep[y]-dep[f[u][i]]+dis[y];
v2=dep[f[u][i]]-dep[x]+dis[x];
if(dep[f[u][i]]>dep[x]&&(v1<v2||v1==v2&&to[y]<to[x])) u=f[u][i];
}
ans[to[y]]+=siz[u]-siz[y];
ans[to[x]]-=siz[u];
}
void dfs1(int u,int fa){
dis[u]=inf;
for(int i=0;i<v[u].size();i++){
int e=v[u][i];
dfs1(e,u);
int vt=dep[e]-dep[u];
if(dis[e]+vt<dis[u]){
dis[u]=dis[e]+vt;
to[u]=to[e];
}
else if(dis[e]+vt==dis[u]){
to[u]=min(to[u],to[e]);
}
}
if(z[u]){
dis[u]=0;to[u]=u;
}
}
void dfs2(int u,int fa){
for(int i=0;i<v[u].size();i++){
int e=v[u][i];
int vt=dep[e]-dep[u];
if(dis[u]+vt<dis[e]){
dis[e]=dis[u]+vt;
to[e]=to[u];
}
else if(dis[u]+vt==dis[e]){
to[e]=min(to[e],to[u]);
}
dd(u,e);
dfs2(e,u);
}
ans[to[u]]+=siz[u];
z[u]=0;
v[u].clear();
}
int main()
{
scanf("%d",&n);
for(int i=1,x,y;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
pt(1,0);
scanf("%d",&q);
while(q--){
scanf("%d",&k);
for(int i=1;i<=k;i++){
scanf("%d",&a[i]);b[i]=a[i];
z[a[i]]=1;ans[a[i]]=0;
}
sort(a+1,a+k+1,cmp);
s[top=1]=1;
for(int i=1;i<=k;i++){
if(a[i]!=1) insert(a[i]);
}
while(top>1){
v[s[top-1]].push_back(s[top]);--top;
}
dfs1(1,0);dfs2(1,0);
for(int i=1;i<=k;i++){
printf("%d ",ans[b[i]]);
}
printf("\n");
}
return 0;
}