首先明显是个虚树,这不多说了。
然后考虑虚树上怎么做这个东西。
类比一下,抽象成每个关键点有个人在跑,看谁跑的快那个点是谁的。
首先虚树叶子的子树肯定归这个节点管。进一步的,每个点的儿子子树如果不在虚树上,肯定归它管。
但是虚树上不只有关键点,还有 LCA(注意这里我已开始脑抽了,以为虚树上非关键点上方是不会有关键点的)。
所以我们得记一下虚树上每个点最近的关键点 b l i bl_i bli,这样就讨论完了第一种情况。
然后考虑往上转移,当前讨论一对父子节点 ( u , v ) (u,v) (u,v) ,如果两者最近的关键点是一样的,从 v v v 跑到 u u u 的路径上,包括路径上挂着的其他节点,最近的节点肯定是 b l u bl_u blu。这个贡献是 s z f v − s z v sz_{fv}-sz_v szfv−szv( f v fv fv 表示 u u u 管辖 v v v 的子节点,这个倍增即可)。
否则, v v v 到 u u u 的路径上一定有一些节点归 b l v bl_v blv 管辖,另一些归 b l u bl_u blu 管辖而且肯定连续,不难想到倍增跳,于是也很好搞。
这题就完了,但是实现起来细节贼多。
- 数组清空
- 算 b l bl bl 数组的时候容易写挂,最好按着定义来写而不是 DP。
- 倍增跳的时候不能跳到 u u u 上方。
又臭又长的代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+5;
int n,m;
struct Edge{
int u,v,w,nxt;
}e[maxn];
int nd[maxn];
struct node{
int w,id;
}a[maxn];
inline bool operator <(node x,node y){
return x.w<y.w;
}
inline bool operator ==(node x,node y){
return (x.id==y.id);
}
const int inf=1e7+5;
vector<int>G[maxn];
stack<int>S;
int h[maxn],cnt=0;
void add(int u,int v,int w){
// cout<<"add"<<endl;
e[++cnt].u=u;
e[cnt].v=v;
e[cnt].w=w;
e[cnt].nxt=h[u];
h[u]=cnt;
}
int ans[maxn];
int sz[maxn],dfn[maxn],dep[maxn],low[maxn],Log[maxn],rt[maxn][19],mn[maxn][19],dc=0;
inline bool bel(int x,int y){
return (dfn[x]<=dfn[y]&&dfn[y]<=low[x]);
}
inline int LCA(int u,int v){
if(dep[u]<dep[v])swap(u,v);
while(dep[u]>dep[v])u=rt[u][Log[dep[u]-dep[v]]];
if(u==v)return u;
for(int i=17;i>=0;i--){
if(rt[u][i]!=rt[v][i]&&rt[u][i]&&rt[v][i]){
u=rt[u][i];
v=rt[v][i];
}
}
return rt[u][0];
}
int dis(int u,int v){
return dep[u]+dep[v]-2*dep[LCA(u,v)];
}
bool core[maxn],mark[maxn];
int rn=0;
int f[maxn],bl[maxn],g[maxn],blg[maxn];
inline int go_up(int u,int p){
for(int i=17;i>=0;i--){
if((1<<i)<=p&&rt[u][i]){
p-=(1<<i);
u=rt[u][i];
}
}
return u;
}
void dfs2(int u,int fa){
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa)continue;
if(bl[v]==-1){
bl[v]=bl[u];
}
else{
int d1=dis(v,bl[u]),d2=dis(v,bl[v]);
if(d1<d2||(d1==d2&&bl[v]>bl[u]))bl[v]=bl[u];
}
dfs2(v,u);
}
}
void dfs1(int u,int fa){
if(core[u])bl[u]=u;
else bl[u]=-1;
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa)continue;
// cout<<u<<"->"<<v<<endl;
dfs1(v,u);
if(bl[v]!=-1){
if(bl[u]==-1)bl[u]=bl[v];
else{
int d1=dis(u,bl[v]),d2=dis(u,bl[u]);
if(d1<d2||(d1==d2&&bl[u]>bl[v]))bl[u]=bl[v];
}
}
}
return ;
}
void dfs(int u,int fa){
dep[u]=dep[fa]+1;
sz[u]=1;
rt[u][0]=fa;
for(int i=1;i<=Log[dep[u]];i++){
rt[u][i]=rt[rt[u][i-1]][i-1];
}
dfn[u]=++dc;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(v!=fa){
dfs(v,u);
sz[u]+=sz[v];
}
}
low[u]=dc;
}
int solve(int u,int v){
int fv=v;
for(int i=17;i>=0;i--){
if(rt[v][i]&&dep[rt[v][i]]>dep[u]){
if((dis(bl[u],rt[v][i])<dis(bl[fv],rt[v][i]))||((dis(bl[u],rt[v][i])==dis(bl[fv],rt[v][i]))&&(bl[u]<bl[fv]))){
}
else{
v=rt[v][i];
}
}
}
return v;
}
void dfs3(int u,int fa){
// cout<<u<<"<->"<<bl[u]<<endl;
int S=sz[u];
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa)continue;
dfs3(v,u);
int rv=go_up(v,dep[v]-dep[u]-1);
// cout<<u<<" RV "<<rv<<endl;
S-=sz[rv];
if(bl[u]==bl[v]){
ans[bl[u]]+=(sz[rv]-sz[v]);
}
else{
int mid=solve(u,v);
ans[bl[u]]+=(sz[rv]-sz[mid]);
ans[bl[v]]+=(sz[mid]-sz[v]);
}
}
ans[bl[u]]+=S;
}
inline bool cmp(node a,node b){
return a.id<b.id;
}
signed main(){
Log[0]=-1;
for(int i=1;i<maxn;i++){
Log[i]=Log[i/2]+1;
}
scanf("%d",&n);
for(int i=1;i<n;i++){
int u,v,w;
scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
// dep[0]=1;
dfs(1,0);
scanf("%d",&m);
for(int i=1;i<=m;i++){
// memset(ans,0,sizeof(ans));
cnt=0;
int k;
scanf("%d",&k);
rn=k;
while(S.size())S.pop();
for(int i=1;i<=k;i++){
int u;
scanf("%d",&u);
nd[i]=u;
ans[u]=0;
a[i].id=u;
a[i].w=dfn[u];
core[u]=1;
}
int rk=k;
sort(a+1,a+k+1);
for(int i=1;i<k;i++){
a[k+i].id=LCA(a[i].id,a[i+1].id);
a[k+i].w=dfn[a[k+i].id];
}
k+=(k);
a[k].id=1;
a[k].w=dfn[1];
sort(a+1,a+k+1,cmp);
int tk=unique(a+1,a+k+1)-a-1;
sort(a+1,a+tk+1);
for(int i=1;i<=tk;i++){
while(S.size()&&!bel(S.top(),a[i].id))S.pop();
if(S.size())add(S.top(),a[i].id,dep[a[i].id]-dep[S.top()]);
S.push(a[i].id);
// cout<<a[i].id<<" ";
}
// cout<<a[1].id<<endl;
dfs1(a[1].id,0);
dfs2(a[1].id,0);
dfs3(a[1].id,0);
for(int i=1;i<=rk;i++){
printf("%d ",ans[nd[i]]);
}
printf("\n");
for(int i=1;i<=tk;i++){
core[a[i].id]=0;
h[a[i].id]=0;
mark[a[i].id]=0;
ans[a[i].id]=0;
f[a[i].id]=inf;
g[a[i].id]=inf;
bl[a[i].id]=0;
blg[a[i].id]=0;
}
}
}