这题显然是虚树:建出虚树后,先求出每个虚树节点被哪个关键点支配(即离它最近的关键点,距离相同时取编号较小者)。
这可以用两遍 dfs 求出:第一遍从儿子转移到父亲,第二遍从父亲转移到儿子。
然后对每条虚树边 $(u,t)$($u$ 是 $t$ 在虚树上的父亲)考虑:设 $f$ 为 $u$ 在原树上通往 $t$ 的那个儿子。如果两个端点被同一个点支配,显然这个点加上 $siz[f]-siz[t]$。
否则可以直接算出分界点 $mid$ 的深度,从 $t$ 倍增向上跳即可找到它:$u$ 的支配点加上 $siz[f]-siz[mid]$,$t$ 的支配点加上 $siz[mid]-siz[t]$。
还有不在虚树上的子树要单独计算:虚树节点 $u$ 本身以及挂在 $u$ 上、不通向任何虚树儿子的子树都归 $u$ 的支配点,数量为 $siz[u]$ 减去所有对应 $siz[f]$ 之和。
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
// Read one decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them makes the result negative.
// The character is kept as int so EOF stays distinguishable from real bytes; if EOF is reached
// before any digit, 0 is returned instead of looping forever.
int read() {
    int sign = 1;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    int value = 0;
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
// Print x in decimal to stdout with no separator.
// The magnitude is computed in unsigned arithmetic, so INT_MIN is printed correctly
// (negating it as an int would be signed overflow).
void write(int x) {
    unsigned int magnitude = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        magnitude = 0u - magnitude;
    }
    char digits[3 * sizeof(unsigned int) + 1];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + magnitude % 10u);
        magnitude /= 10u;
    } while (magnitude != 0u);
    while (len > 0) putchar(digits[--len]);
}
// Print x followed by a newline.
void wln(int x) {
    write(x);
    putchar('\n');
}

// Print x followed by a single space.
void wsp(int x) {
    write(x);
    putchar(' ');
}
// N: capacity of every per-vertex array. P is not referenced by any code in this file section.
const int N = 3e5+7, P=998244353;
// Forward-star adjacency entry: target vertex t and index of the next edge with the same source
// (0 terminates the list). Sized N<<1 because every tree edge is stored in both directions.
struct edge {
int t,nxt;
}e[N<<1];
// n: vertex count; Q: remaining queries; cnt: number of entries used in e; K: key vertices in the
// current query; T: DFS entry-time counter.
int n,Q,cnt,K,T;
// head[v]: first edge of v; pos[v]: DFS entry time; siz[v]: subtree size; dep[v]: depth (root 1 has
// depth 1, dep[0] stays 0); fa[v][i]: 2^i-th ancestor (0 above the root); h: current query's key
// vertices; st: virtual-tree build stack, st[0] stores its size; flg[v]: 1-based input position of v
// in the current query (0 if v is not a key vertex); num[i]: answer for the i-th key vertex;
// bel[v]: key vertex that virtual-tree node v is assigned to.
int head[N],pos[N],siz[N],dep[N],fa[N][20],h[N],st[N],flg[N],num[N],bel[N];
// E[v]: children of v in the current query's virtual tree (emptied again by dfs3).
vector<int> E[N];
// Append the directed edge u -> t to u's forward-star list.
// Uses standard aggregate initialization; the previous `(edge){...}` form was a C99 compound
// literal, which ISO C++ does not allow (it only compiled as a compiler extension).
void add(int u, int t) {
    e[++cnt] = edge{t, head[u]};
    head[u] = cnt;
}
// Strict weak ordering of vertices by DFS entry time; used to sort a query's key vertices.
bool cmp(int a, int b) {
    return pos[b] > pos[a];
}
// Root the original tree: for vertex x with parent f, record its DFS entry time pos[x], depth dep[x],
// binary-lifting table fa[x][0..19] and subtree size siz[x].
// NOTE(review): recursion depth equals the tree height (up to n), so a path-shaped tree may need a
// large stack.
void dfs(int x, int f) {
    pos[x] = ++T;
    dep[x] = dep[f] + 1;
    fa[x][0] = f;
    for (int k = 0; k < 19; ++k) fa[x][k + 1] = fa[fa[x][k]][k];
    siz[x] = 1;
    for (int id = head[x]; id != 0; id = e[id].nxt) {
        const int y = e[id].t;
        if (y == f) continue;  // do not walk back to the parent
        dfs(y, x);
        siz[x] += siz[y];
    }
}
// Lowest common ancestor of a and b using the binary-lifting table fa[][] and depths dep[].
int LCA(int a, int b) {
    if (dep[a] > dep[b]) std::swap(a, b);  // make b the deeper vertex
    // Lift b up to a's depth.
    for (int k = 19; k >= 0; --k) {
        if (dep[fa[b][k]] >= dep[a]) b = fa[b][k];
    }
    if (a == b) return a;
    // Lift both while they stay distinct; afterwards their common parent is the LCA.
    for (int k = 19; k >= 0; --k) {
        if (fa[a][k] != fa[b][k]) {
            a = fa[a][k];
            b = fa[b][k];
        }
    }
    return fa[a][0];
}
// Number of vertices on the tree path a..b, i.e. the edge distance plus one.
int dis(int a, int b) {
    const int top = LCA(a, b);
    return (dep[a] - dep[top]) + (dep[b] - dep[top]) + 1;
}
// Push key vertex x (supplied in DFS-entry order) onto the virtual-tree stack st, whose size is kept
// in st[0]. Stack entries that are no longer on x's root path are popped, and the virtual-tree
// edges they complete are appended to E. Their LCA with x is pushed if it was not already present.
void insert(int x) {
    int &top = st[0];
    if (top > 1) {
        const int anc = LCA(st[top], x);
        for (; top > 1 && pos[st[top - 1]] >= pos[anc]; --top)
            E[st[top - 1]].push_back(st[top]);
        if (st[top] != anc) {
            E[anc].push_back(st[top]);
            st[top] = anc;
        }
    }
    st[++top] = x;
}
// Bottom-up pass over the virtual tree. A key vertex owns itself. Otherwise bel[x] becomes the
// shallowest key vertex found in the children's results (ties: smaller id), or 0 if there is none.
void dfs1(int x) {
    bel[x] = flg[x] ? x : 0;
    for (int child : E[x]) {
        dfs1(child);
        const int cand = bel[child];
        const int cur = bel[x];
        const bool takeChild = !cur
            || dep[cur] > dep[cand]
            || (dep[cur] == dep[cand] && cur > cand);
        if (takeChild) bel[x] = cand;
    }
}
void dfs2(int x) {
for (int i = 0; i < E[x].size(); i++) {
int t = E[x][i], d1 = dis(t, bel[x]), d2 = dis(t, bel[t]);
if (d1 < d2 || (d1 == d2 && bel[x] < bel[t])) bel[t] = bel[x];
dfs2(t);
}
}
// Climb from x to its shallowest ancestor whose depth is still >= de.
// For 1 <= de <= dep[x] this is the ancestor of x at depth exactly de; if de > dep[x], x is returned.
int find(int x, int de) {
    for (int k = 19; k >= 0; --k) {
        const int up = fa[x][k];
        if (dep[up] >= de) x = up;
    }
    return x;
}
// Final pass over the virtual tree (run after dfs1 and dfs2). Every original-tree vertex in x's
// subtree is credited to the key vertex that owns it, via num[flg[owner]]. E[x] is then cleared so
// the arrays are ready for the next query.
void dfs3(int x) {
// Start from all of x's subtree; the part hanging toward each virtual child is subtracted below.
// The remainder (x itself plus branches with no key vertex) belongs to bel[x].
int Sz = siz[x];
for (int i = 0, mi; i < E[x].size(); i++) {
// f: the original-tree child of x on the way to virtual child t (f == t when they are adjacent).
int t = E[x][i], f = find(t, dep[x]+1);
if (f != t) {
// Both ends share one owner: the chain f..t plus its side branches (all of f's subtree except
// t's subtree) goes to it.
if (bel[t] == bel[x]) {
num[flg[bel[x]]] += siz[f] - siz[t];
}
else {
// Owners differ, so bel[t] was set by dfs1 and lies in t's subtree. ln counts the vertices on the
// path bel[x]..bel[t] (dis adds 1 to the edge count). mi becomes the highest chain vertex owned by
// bel[t]. An exact midpoint exists only when ln is odd; there the smaller id wins, as in dfs1/dfs2.
int ln = dis(bel[x], bel[t]);
if (ln&1) {
mi = find(t, dep[bel[t]]-ln/2+(bel[t] > bel[x]));
}
else mi = find(t, dep[bel[t]]-ln/2+1);
// Above mi (inside f's subtree) -> bel[x]; mi's subtree minus t's subtree -> bel[t].
num[flg[bel[x]]] += siz[f] - siz[mi];
num[flg[bel[t]]] += siz[mi] -siz[t];
}
}
dfs3(t);
Sz -= siz[f];
}
num[flg[bel[x]]] += Sz;
E[x].clear();
}
// Entry point. Read a tree of n vertices (n-1 undirected edges) and root it at vertex 1. Then
// answer each query: given K key vertices, print (in input order) how many vertices each one
// controls, space-separated, one line per query.
int main() {
    // freopen("3572.in", "r", stdin);
    // freopen("3572.out", "w", stdout);
    n = read();
    for (int remaining = n - 1; remaining > 0; --remaining) {
        const int a = read();
        const int b = read();
        add(a, b);
        add(b, a);
    }
    dfs(1, 0);

    Q = read();
    for (; Q; --Q) {
        K = read();
        for (int j = 1; j <= K; ++j) {
            h[j] = read();
            flg[h[j]] = j;  // remember each key vertex's input position
            num[j] = 0;
        }
        sort(h + 1, h + K + 1, cmp);

        // Build the virtual tree; vertex 1 is always its root.
        if (h[1] != 1) {
            st[0] = 1;
            st[1] = 1;
        }
        for (int j = 1; j <= K; ++j) insert(h[j]);
        while (st[0] > 1) {
            E[st[st[0] - 1]].push_back(st[st[0]]);
            --st[0];
        }
        st[0] = 0;

        dfs1(1);
        dfs2(1);
        dfs3(1);

        for (int j = 1; j <= K; ++j) {
            flg[h[j]] = 0;
            wsp(num[j]);
        }
        putchar('\n');
    }
    return 0;
}