学习了一种新的树形结构——虚树,就是把每个树上是链的部分缩起来,再对这些在虚树上的点进行操。
这题其实几天前就A掉了。。。只是题解还没写……
这题有个很直观的想法,就是把询问点按dfs序排序,然后只保留相邻点的lca。然后再乱搞。。。
然后有个很直观的感觉就是不可打(当然这是因为我太弱了。。。)
然后去膜拜ydc神犇的代码。
然后发现非常好写!
然后还明白了这个东西叫做虚树。。。
具体做法是:
先建出虚树:求出lca,然后模拟dfs,维护每个点的父亲结点。如果栈顶不是下一个点的祖先就出栈。记得如果在一条边中间插入要更新下面那个点的父亲信息。
具体代码如下:
h[N]是读入的点,d[N]是深度,t[N]是虚树的结点。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <algorithm>
using namespace std;
#define pii pair<int,int>
#define mp make_pair
#define INF 2100000000
#define maxn 310000
#define deg 20
#define pb(a) push_back(a)
int getint()
{
int res;char c;
while(c=getchar(),c<'0'||c>'9');
res=c-'0';
while(c=getchar(),c>='0'&&c<='9')
res=res*10+c-'0';
return res;
}
int en,indexs,n,fa[maxn][deg],pre[maxn],id[maxn],siz[maxn];
int to[maxn<<1],first[maxn<<1],next[maxn<<1],dep[maxn],dfn[maxn];
void build(int a,int b)
{
en++;
to[en]=b;
next[en]=first[a];
first[a]=en;
}
void dfs(int now)
{
int v;
indexs++;siz[now]=1;
dfn[now]=indexs;
for(int i=first[now];i;i=next[i])
{
v=to[i];
if(v==fa[now][0]) continue;
fa[v][0]=now;
for(int j=1;j<deg;j++) fa[v][j]=fa[fa[v][j-1]][j-1];
dep[v]=dep[now]+1;
dfs(v);
siz[now]+=siz[v];
}
}
int len,mem[maxn],val[maxn],w[maxn];
int ans[maxn];
int h[maxn];
vector<int>save;
bool cmp(int a,int b)
{
return dfn[a]<dfn[b];
}
int lca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
int tv=v,tu=v;
for(int i=19;i>=0;i--)
{
if(dep[fa[u][i]]>=dep[v])
u=fa[u][i];
}
if(u==v)return v;
for(int i=19;i+1;i--)
if(fa[u][i]!=fa[v][i])
{
u=fa[u][i];
v=fa[v][i];
}
return fa[u][0];
}
pii g[maxn];
int num;
int sta[maxn],top;
void virtree()
{
top=0;
for(int i,j=1;j<=len;j++)
{
i=h[j];
if(top==0)
{
top++;
sta[top]=i;
pre[i]=0;
continue;
}
int anc=lca(i,sta[top]);
for(;top;top--)
{
if(dep[anc]>=dep[sta[top]]) break;
if(dep[sta[top-1]]<=dep[anc])
{
pre[sta[top]]=anc;
}
}
if(sta[top]!=anc)
{
save.pb(anc);
pre[anc]=sta[top];
top++;sta[top]=anc;
g[anc]=mp(INF,0);
}
pre[i]=anc;top++;
sta[top]=i;
}
}
inline int get(int x,int d)
{
for(int i=19;i+1;i--)
if(dep[fa[x][i]]>=d)x=fa[x][i];
return x;
}
void solve()
{
save.clear();
int a,b,c,num;
len=getint();
for(int i=1;i<=len;i++)
{
h[i]=getint();
mem[i]=h[i];
save.pb(h[i]);
ans[h[i]]=0;
g[h[i]]=mp(0,h[i]);
}
ans[0]=0;
sort(h+1,h+len+1,cmp);
virtree();
sort(save.begin(),save.end(),cmp);
num=save.size();
for(int i=1;i<=num;i++)
{
int z=save[i-1];
val[z]=siz[z];
if(i>1)
w[z]=dep[z]-dep[pre[z]];
}
for(int i=num;i>1;i--)
{
int x=save[i-1];
int father=pre[x];
g[father]=min(g[father],mp(g[x].first+w[x],g[x].second));
}
for(int i=2;i<=num;i++)
{
int x=save[i-1],father=pre[x];
g[x]=min(g[x],mp(g[father].first+w[x],g[father].second));
}
for(int i=1;i<=num;i++)
{
int x=save[i-1],father=pre[x];
if(i==1)
{
ans[g[x].second]+=n-siz[x];
continue;
}
int G=get(x,dep[father]+1);
int sum=siz[G]-siz[x];
val[father]-=siz[G];
if(g[x].second==g[father].second)
{
ans[g[x].second]+=sum;
continue;
}
int mid=dep[x]-((g[father].first+g[x].first+w[x])/2-g[x].first);
if(!((g[father].first+g[x].first+w[x])&1)&&g[father].second<g[x].second)++mid;
int y=siz[get(x,mid)]-siz[x];
ans[g[x].second]+=y;
ans[g[father].second]+=sum-y;
}
for(int i=1;i<=num;i++)
ans[g[save[i-1]].second]+=val[save[i-1]];
for(int i=1;i<=len;i++)
printf("%d ",ans[mem[i]]);
printf("\n");
}
int main()
{
int a,b,c;
n=getint();
for(int i=1;i<n;i++)
{
a=getint();
b=getint();
build(a,b);
build(b,a);
}
dep[1]=1;
dfs(1);
int cas=getint();
while(cas--)
{
solve();
}
return 0;
}
/*
10
2 1
3 2
4 3
5 4
6 1
7 3
8 3
9 4
10 1
5
2
6 1
5
2 7 3 6 9
1
8
4
8 7 10 3
5
2 9 3 5 8
*/