世界树......冰封王座肯定不是临时议事处。
网上神犇的题解都说要用虚树, 我试了各种办法都没搜到跟虚树有关系的东西QAQ。 终于, 在贴吧大神和DG的帮助下搞懂了这道题。 虽然神犇们不说什么是虚树, 我在这里介绍这一题中的应用。(当然, 有些大神其实根本不知道什么是虚树也乱发题解了)
对于每次询问的m个点, 在原树上将它们连接起来形成一个子图, 同时把这m个点的lca加入子图, 对非询问点而度数为2的点只要缩掉, 就得到了与当前询问对应的一棵虚树。(个人感觉不用在意这个名字, 只是重新构造了一个图而已)
这时我们会发现对于虚树中的每一条边, 其两端的节点都有自己的归属(这是句废话), 那么讨论这条链上的点的归属情况, 最后统计即可得到答案。
建虚树的过程请看代码......反正我当时也是看代码总结的, 总之就是用栈维护一条最右链, 每次用还未加入的元素以及该元素和栈顶的lca更新就行了。
ps: 见到奇奇怪怪的变量名称不要被吓到, 其实是有意义的。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#define N 300000 + 10
#define H 20
#define INF 1000000000
using namespace std;
struct node
{
int dist, bel;
node() { }
node(int x, int y)
{
dist = x;
bel = y;
}
}g[N];
struct edge
{
int to, next;
}e[2*N];
int n, m, q, num, top, ind, cont, p[N], flag[N];
int fa[N][H], fa_v[N], r[N], d[N], h[N], s[N], st[N];
int sum[N], dfn[N], ans[N], val[N], delta[N];
inline bool cmp(int x, int y)
{ return dfn[x] < dfn[y]; }
node min(node a, node b)
{
if (a.dist == b.dist) return a.bel < b.bel ? a : b;
return a.dist < b.dist ? a : b;
}
void read(int &x)
{
x = 0;
char c = getchar();
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9')
{
x = 10*x + c - '0';
c = getchar();
}
}
void add(int x, int y)
{
e[++num].to = y;
e[num].next = p[x];
p[x] = num;
}
void init()
{
int x, y;
read(n);
for (int i = 1; i < n; ++i)
{
read(x), read(y);
add(x, y);
add(y, x);
}
}
void bfs()
{
queue<int>q;
q.push(1);
fa[1][0] = 1;
flag[1] = 1;
while(!q.empty())
{
int x = q.front();
q.pop();
for (int i = 1; i < H; ++i)
fa[x][i] = fa[fa[x][i-1]][i-1];
for (int i = p[x]; i; i = e[i].next)
{
int k = e[i].to;
if (!flag[k])
{
d[k] = d[x] + 1;
fa[k][0] = x;
flag[k] = 1;
q.push(k);
}
}
}
}
void dfs(int x)
{
sum[x] = 1;
dfn[x] = ++ind;
for (int i = p[x]; i; i = e[i].next)
{
int k = e[i].to;
if (k != fa[x][0])
{
dfs(k);
sum[x] += sum[k];
}
}
}
int lca(int x, int y)
{
if (d[x] > d[y]) swap(x, y);
int l = x, r = y;
for (int mid = d[r] - d[l], i = 0; mid; ++i, mid >>= 1)
if (mid & 1) r = fa[r][i];
if (l == r) return r;
for (int i = H - 1; i >= 0; i--)
{
if (fa[l][i] == fa[r][i]) continue;
l = fa[l][i], r = fa[r][i];
}
return fa[r][0];
}
int find(int x, int h)
{
for (int i = H - 1; i >= 0; i--)
if (d[fa[x][i]] >= h) x = fa[x][i];
return x;
}
void solve()
{
top = cont = 0;
read(m);
for (int i = 1; i <= m; ++i)
{
read(h[i]);
s[++cont] = r[i] = h[i];
g[h[i]] = node(0, h[i]);
ans[h[i]] = 0;
}
sort(h+1, h+m+1, cmp);
for (int i = 1; i <= m; ++i)
{
int x = h[i];
if (!top)
{
st[++top] = x;
fa_v[x] = 0;
}
else
{
int anc = lca(x, st[top]);
while(d[st[top]] > d[anc])
{
if (d[st[top-1]] <= d[anc])
fa_v[st[top]] = anc;
top--;
}
if (st[top] != anc)
{
s[++cont] = anc;
g[anc] = node(INF, 0);
fa_v[anc] = st[top];
st[++top] = anc;
}
fa_v[x] = anc;
st[++top] = x;
}
}
sort(s+1, s+cont+1, cmp);
for (int i = 1; i <= cont; ++i)
{
int x = s[i];
val[x] = sum[x];
if (i > 1) delta[x] = d[x] - d[fa_v[x]];
}
for (int i = cont; i > 1; i--)
{
int x = s[i];
g[fa_v[x]] = min(g[fa_v[x]], node(g[x].dist+delta[x], g[x].bel));
}
for (int i = 2; i <= cont; ++i)
{
int x = s[i];
g[x] = min(g[x], node(g[fa_v[x]].dist+delta[x], g[fa_v[x]].bel));
}
for (int i = 1; i <= cont; ++i)
{
int x = s[i], anc = fa_v[x];
if (i == 1) ans[g[x].bel] += n - sum[x];
else
{
int k = find(x, d[anc]+1);
int del = sum[k] - sum[x];
val[anc] -= sum[k];
if (g[anc].bel == g[x].bel) ans[g[x].bel] += del;
else
{
int mid = d[x] - ((g[x].dist+g[fa_v[x]].dist+delta[x])/2-g[x].dist);
if ((g[x].dist+g[fa_v[x]].dist+delta[x]) % 2 == 0 && g[x].bel > g[fa_v[x]].bel) ++mid;
int tmp = sum[find(x, mid)] - sum[x];
ans[g[x].bel] += tmp;
ans[g[fa_v[x]].bel] += del - tmp;
}
}
}
for (int i = 1; i <= cont; ++i)
ans[g[s[i]].bel] += val[s[i]];
for (int i = 1; i <= m; ++i)
printf("%d ", ans[r[i]]);
putchar('\n');
}
void deal()
{
bfs();
dfs(1);
read(q);
while(q--) solve();
}
int main()
{
init();
deal();
return 0;
}