题目:
https://www.luogu.org/problemnew/show/P3233
分析:
先把虚树建出来,顺便把根节点插入到虚树中。
然后进行dp求出虚树上每个节点到最近关键点的编号。可以两次bfs求,因为有可能是他的兄弟最近。我们先求儿子对父亲的影响,再让父亲更新儿子,这样就可以求出来了。
考虑怎么求答案。对于每一个虚树节点维护一个
s
u
m
[
x
]
sum[x]
sum[x],先把他赋值成
s
i
z
e
[
x
]
size[x]
size[x]。如果一条边两端最近关键点相同,那么使父亲的
s
u
m
sum
sum减去儿子的
s
i
z
e
size
size;如果不同,那么倍增出一个分界点
d
d
d,父亲减去
s
i
z
e
[
d
]
size[d]
size[d],儿子加上
s
i
z
e
[
d
]
−
s
i
z
e
[
s
o
n
]
size[d]-size[son]
size[d]−size[son]。
最后一个关键点的答案就是虚树上以他作为最近关键点的节点的
s
u
m
sum
sum的和。
代码:
// luogu-judger-enable-o2
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
const int maxn=3e5+7;
using namespace std;
int n,m,T,x,y,cnt,top;
int ls[maxn],f[maxn][20],dfn[maxn],dep[maxn],size[maxn];
int a[maxn],bel[maxn],sum[maxn],ans[maxn],q[maxn],pre[maxn];
struct edge{
int y,next;
}g[maxn*2];
struct node{
int x,num;
}b[maxn];
bool cmpa(int a,int b)
{
return dfn[a]<dfn[b];
}
bool cmpb(node a,node b)
{
return dfn[a.x]<dfn[b.x];
}
bool cmpc(node a,node b)
{
return a.num<b.num;
}
void add(int x,int y)
{
g[++cnt]=(edge){y,ls[x]};
ls[x]=cnt;
}
void dfs(int x,int fa)
{
dfn[x]=++cnt;
size[x]=1;
f[x][0]=fa;
dep[x]=dep[fa]+1;
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if (y==fa) continue;
dfs(y,x);
size[x]+=size[y];
}
}
int up(int x,int d)
{
int k=19,t=1<<k;
while (d)
{
if (d>=t) d-=t,x=f[x][k];
t/=2,k--;
}
return x;
}
int getlca(int x,int y)
{
if (dep[x]>dep[y]) swap(x,y);
int d=dep[y]-dep[x];
y=up(y,d);
if (x==y) return x;
int k=19;
while (k>=0)
{
if (f[x][k]!=f[y][k])
{
x=f[x][k];
y=f[y][k];
}
k--;
}
return f[x][0];
}
int getdis(int x,int y)
{
int d=getlca(x,y);
return dep[x]+dep[y]-2*dep[d];
}
int main()
{
scanf("%d",&n);
for (int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
cnt=0;
dfs(1,0);
for (int j=1;j<20;j++)
{
for (int i=1;i<=n;i++) f[i][j]=f[f[i][j-1]][j-1];
}
scanf("%d",&T);
while (T--)
{
scanf("%d",&m);
for (int i=1;i<=m;i++) scanf("%d",&b[i].x),b[i].num=i;
sort(b+1,b+m+1,cmpb);
top=cnt=0;
a[++cnt]=1;
q[++top]=1;
bel[1]=0;
for (int i=1;i<=m;i++)
{
if (b[i].x==1)
{
bel[1]=1;
continue;
}
a[++cnt]=b[i].x;
bel[b[i].x]=b[i].x;
int d=getlca(q[top],b[i].x);
if (d==q[top]) q[++top]=b[i].x;
else
{
while (dep[q[top-1]]>=dep[d])
{
pre[q[top]]=q[top-1];
top--;
}
if (q[top]!=d)
{
pre[q[top]]=d;
q[top]=d;
a[++cnt]=d;
bel[d]=0;
}
q[++top]=b[i].x;
}
}
while (top>1) pre[q[top]]=q[top-1],top--;
sort(a+1,a+cnt+1,cmpa);
for (int i=cnt;i>1;i--)
{
int len1=dep[bel[a[i]]]-dep[pre[a[i]]];
int len2=dep[bel[pre[a[i]]]]-dep[pre[a[i]]];
if ((!bel[pre[a[i]]]) || (len1<len2) || ((len1==len2) && (bel[pre[a[i]]]>bel[a[i]]))) bel[pre[a[i]]]=bel[a[i]];
}
for (int i=2;i<=cnt;i++)
{
int len1=getdis(a[i],bel[pre[a[i]]]);
int len2=getdis(a[i],bel[a[i]]);
if ((!bel[a[i]]) || (len1<len2) || ((len1==len2) && (bel[a[i]]>bel[pre[a[i]]]))) bel[a[i]]=bel[pre[a[i]]];
}
for (int i=1;i<=cnt;i++) ans[a[i]]=0,sum[a[i]]=size[a[i]];
for (int i=2;i<=cnt;i++)
{
if (bel[a[i]]==bel[pre[a[i]]]) sum[pre[a[i]]]-=size[a[i]];
else
{
int len=getdis(bel[a[i]],bel[pre[a[i]]]);
int d=up(bel[a[i]],(len-1)/2);
if ((len%2==0) && (bel[a[i]]<bel[pre[a[i]]])) d=f[d][0];
sum[pre[a[i]]]-=size[d];
sum[a[i]]+=size[d]-size[a[i]];
}
}
for (int i=1;i<=cnt;i++) ans[bel[a[i]]]+=sum[a[i]];
sort(b+1,b+m+1,cmpc);
for (int i=1;i<=m;i++) printf("%d ",ans[b[i].x]);
printf("\n");
}
}