【题目链接】（原文链接缺失；从代码判断应为 HNOI2014「世界树」，即 BZOJ 3572 / 洛谷 P3233，待核实）
【思路要点】
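- 补档博客，无题解。代码要点：每次询问对关键点按 DFS 序用单调栈建虚树；两遍 DFS 求出虚树上每个点最近的关键点（距离相同取编号较小者）；再对虚树的每条边用倍增找到原树路径上的分界点，按子树大小把原树节点分给两端的关键点。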
【代码】
#include <bits/stdc++.h>
using namespace std;

// The program assigns every vertex of the input tree to its nearest marked
// (query) vertex, ties broken by the smaller vertex index, and for each query
// prints, in input order, how many vertices each marked vertex receives.
// Per query it builds a virtual tree on the marked vertices and runs two DFS
// passes over it.
//
// NOTE(review): init/dfsI/dfsII are recursive; on a path-shaped tree the
// recursion depth can reach ~n, which may overflow a small default stack.

const int MAXN = 300005;
const int MAXLOG = 20;       // binary-lifting levels: 2^(MAXLOG-1) > MAXN
const int INF = 1000000000;  // distance sentinel meaning "no marked vertex"

// A candidate governing vertex: `pos` is a marked vertex (0 when none) and
// `dist` is its distance from the vertex the value is stored for.
struct info {
    int pos, dist;
};

info dmin[MAXN];  // best candidate inside the virtual subtree
info dnim[MAXN];  // second-best candidate, from a different source than dmin
info up[MAXN];    // best candidate outside the virtual subtree (self if marked)
info home[MAXN];  // overall best candidate: min(up, dmin)
bool mark[MAXN];  // marked vertices of the current query
// subtree_size was `size`: renamed because a global named `size` is ambiguous
// with std::size (C++17) once `using namespace std;` is in effect.
int timer, subtree_size[MAXN], depth[MAXN], dfn[MAXN];
int father[MAXN][MAXLOG], ans[MAXN];
vector<int> a[MAXN];  // original tree (adjacency lists)
vector<int> b[MAXN];  // virtual tree of the current query (adjacency lists)

// Candidates are ordered by distance first, then by smaller vertex index.
bool operator<(info x, info y) {
    return x.dist < y.dist || (x.dist == y.dist && x.pos < y.pos);
}

// Moves a candidate `d` edges further away.
info operator+(info x, int d) { return info{x.pos, x.dist + d}; }

// Bottom-up pass over the virtual tree: fills dmin/dnim for pos from pos
// itself (if marked) and from every virtual child's dmin plus the edge length.
void dfsI(int pos, int fa) {
    dmin[pos] = mark[pos] ? info{pos, 0} : info{0, INF};
    dnim[pos] = info{0, INF};
    for (int to : b[pos]) {
        if (to == fa) continue;
        dfsI(to, pos);
        info cand = dmin[to] + (depth[to] - depth[pos]);
        if (cand < dmin[pos]) {
            dnim[pos] = dmin[pos];
            dmin[pos] = cand;
        } else if (cand < dnim[pos]) {
            dnim[pos] = cand;
        }
    }
}

// Returns the ancestor of y at depth depth[x] + 1, i.e. the child of x on the
// path to y. Requires x to be a proper ancestor of y.
int childToward(int x, int y) {
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (depth[father[y][i]] > depth[x]) y = father[y][i];
    return y;
}

// Top-down pass over the virtual tree. Computes up/home for pos. It then credits
// ans[] with the original-tree vertices on the virtual edge (fa, pos) and the
// vertices of pos's original subtree that lie outside every virtual child branch.
// `len` is the original-tree distance from fa to pos.
void dfsII(int pos, int fa, int len) {
    if (pos == 1) {
        up[pos] = mark[pos] ? info{pos, 0} : info{0, INF};
        home[pos] = (up[pos] < dmin[pos]) ? up[pos] : dmin[pos];
    } else {
        // Best candidate at fa that does not come from pos's own branch.
        info other = (dmin[fa].pos == dmin[pos].pos) ? dnim[fa] : dmin[fa];
        if (mark[pos])
            up[pos] = info{pos, 0};
        else if (other < up[fa])
            up[pos] = other + len;
        else
            up[pos] = up[fa] + len;
        home[pos] = (up[pos] < dmin[pos]) ? up[pos] : dmin[pos];

        // Binary-lift from pos toward fa. This finds the highest vertex `now`
        // (strictly below fa) that is still strictly closer to home[pos] than
        // to home[fa]; the comparison is monotone along the path.
        int now = pos;
        for (int i = MAXLOG - 1; i >= 0; i--) {
            int cand = father[now][i];
            if (depth[cand] > depth[fa] &&
                home[pos] + (depth[pos] - depth[cand]) <
                    home[fa] + (depth[cand] - depth[fa]))
                now = cand;
        }
        ans[home[fa].pos] += subtree_size[childToward(fa, now)] - subtree_size[now];
        ans[home[pos].pos] += subtree_size[now] - subtree_size[pos];
    }

    // Vertices of pos's subtree not inside any virtual child's branch.
    int rest = subtree_size[pos];
    for (int to : b[pos]) {
        if (to == fa) continue;
        rest -= subtree_size[childToward(pos, to)];
        dfsII(to, pos, depth[to] - depth[pos]);
    }
    ans[home[pos].pos] += rest;
}

// Preprocesses the original tree rooted at 1: subtree sizes, DFS order,
// depths (root has depth 1, sentinel 0 has depth 0) and binary-lifting table.
void init(int pos, int fa) {
    subtree_size[pos] = 1;
    dfn[pos] = ++timer;
    depth[pos] = depth[fa] + 1;
    father[pos][0] = fa;
    for (int i = 1; i < MAXLOG; i++)
        father[pos][i] = father[father[pos][i - 1]][i - 1];
    for (int to : a[pos]) {
        if (to == fa) continue;
        init(to, pos);
        subtree_size[pos] += subtree_size[to];
    }
}

// Orders vertices by DFS entry time.
bool cmp(int x, int y) { return dfn[x] < dfn[y]; }

// Lowest common ancestor via binary lifting.
int lca(int x, int y) {
    if (depth[x] < depth[y]) swap(x, y);
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (depth[father[x][i]] >= depth[y]) x = father[x][i];
    if (x == y) return x;
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (father[x][i] != father[y][i]) {
            x = father[x][i];
            y = father[y][i];
        }
    return father[x][0];
}

// Adds an undirected virtual-tree edge between u and v.
void addEdge(int u, int v) {
    b[u].push_back(v);
    b[v].push_back(u);
}

int main() {
#ifdef LOCAL
    // Local debugging only: an unconditional freopen breaks judge input on stdin.
    freopen("input.txt", "r", stdin);
#endif
    int n;
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        a[x].push_back(y);
        a[y].push_back(x);
    }
    init(1, 0);

    int T;
    scanf("%d", &T);
    while (T--) {
        int m;
        scanf("%d", &m);
        static int q[MAXN], p[MAXN];  // q: sorted by dfn below; p: input order
        for (int i = 1; i <= m; i++) {
            scanf("%d", &q[i]);
            p[i] = q[i];
            mark[q[i]] = true;
        }
        sort(q + 1, q + m + 1, cmp);

        // Build the virtual tree rooted at 1 with a monotone stack in DFS
        // order. used[1..cnt] records every vertex that got virtual edges so
        // its adjacency list can be cleared after the query.
        static int Stack[MAXN], used[MAXN];
        int top = 1, cnt = 1;
        Stack[1] = 1;
        used[1] = 1;
        int s = (q[1] == 1) ? 2 : 1;  // the root is already on the stack
        for (int i = s; i <= m; i++) {
            int Lca = lca(Stack[top], q[i]);
            if (Lca == Stack[top]) {
                Stack[++top] = q[i];
                continue;
            }
            // Entries deeper than Lca are complete: link each to the entry
            // below it on the stack and pop it.
            while (dfn[Lca] < dfn[Stack[top - 1]]) {
                addEdge(Stack[top - 1], Stack[top]);
                used[++cnt] = Stack[top--];
            }
            if (Lca == Stack[top - 1]) {
                addEdge(Stack[top - 1], Stack[top]);
                used[++cnt] = Stack[top--];
                Stack[++top] = q[i];
            } else {
                // Lca is a new branching vertex between Stack[top - 1] and Stack[top].
                addEdge(Lca, Stack[top]);
                used[++cnt] = Stack[top--];
                Stack[++top] = Lca;
                Stack[++top] = q[i];
            }
        }
        while (top >= 2) {
            addEdge(Stack[top - 1], Stack[top]);
            used[++cnt] = Stack[top--];
        }

        dfsI(1, 0);
        dfsII(1, 0, 0);

        for (int i = 1; i <= cnt; i++) b[used[i]].clear();
        for (int i = 1; i <= m; i++) {
            mark[q[i]] = false;
            printf("%d ", ans[p[i]]);
            ans[p[i]] = 0;
        }
        printf("\n");
    }
    return 0;
}