Address
Solution
- 先建出虚树。
- 记 id[x] i d [ x ] 表示距离点 x x 最近的关键点编号, 表示距离点 x x 最近的关键点距离。
- 那么对于一个点 ,最近的关键点既能从它的父节点转移过来,也能从它的子节点转移过来,因此做正反两遍转移。
- 得到这些以后,我们考虑树上所有点的关键点归属情况。
对于虚树上的一条边 y→x y → x (不包括点 x x ,避免重复计数):
- 记 为以 x x 为根的子树大小。
- 若 为虚树的根节点,则点
y
y
上方的 个点必然属于
id[y]
i
d
[
y
]
管辖(如下图红色节点)。
- 考虑用倍增找到该边上与点
y
y
距离为1的点 ,则有
sze[y]−sze[w]
s
z
e
[
y
]
−
s
z
e
[
w
]
个点必然属于
id[y]
i
d
[
y
]
管辖(如下图蓝色节点)。
- 若
id[y]=id[x]
i
d
[
y
]
=
i
d
[
x
]
,则剩余的
sze[w]−sze[x]
s
z
e
[
w
]
−
s
z
e
[
x
]
个点也属于
id[y]
i
d
[
y
]
管辖(如下图绿色节点)。
- 若
id[y]≠id[x]
i
d
[
y
]
≠
i
d
[
x
]
,则可以用倍增找到一个该边上的分界点
z
z
,使得 点上方属于
id[y]
i
d
[
y
]
管辖,
z
z
点及 点下方属于
id[x]
i
d
[
x
]
管辖(如下图黄色与橙色节点,注意距离相等的特殊情况)。
时间复杂度 O(nlogn) O ( n log n ) ,瓶颈在排序。
Code
#include <iostream> #include <cstdio> #include <cctype> #include <algorithm> #include <cstring> using namespace std; namespace inout { const int S = 1 << 20; char frd[S], *ihed = frd + S; const char *ital = ihed; inline char inChar() { if (ihed == ital) fread(frd, 1, S, stdin), ihed = frd; return *ihed++; } inline int get() { char ch; int res = 0; bool flag = false; while (!isdigit(ch = inChar()) && ch != '-'); (ch == '-' ? flag = true : res = ch ^ 48); while (isdigit(ch = inChar())) res = res * 10 + ch - 48; return flag ? -res : res; } char fwt[S], *ohed = fwt; const char *otal = ohed + S; inline void outChar(char ch) { if (ohed == otal) fwrite(fwt, 1, S, stdout), ohed = fwt; *ohed++ = ch; } inline void put(int x) { if (x > 9) put(x / 10); outChar(x % 10 + 48); } }; using namespace inout; const int Maxn = 0x3f3f3f3f; const int N = 3e5 + 5; int dfn[N], fa[N][20], sze[N], dep[N], del[N]; int vir[N], virs[N], par[N], Ans[N], stk[N]; int n, m, q, tis, top, im; struct Edge { int to; Edge *nxt; }p[N << 1], *T = p, *lst[N]; struct point { int dis, id; #define d(x) tr[x].dis #define t(x) tr[x].id point() {} point(const int &D, const int &I): dis(D), id(I) {} inline bool operator < (const point &x) const { return dis == x.dis ? id < x.id : dis < x.dis; } }tr[N]; inline void Link(int x, int y) { (++T)->nxt = lst[x]; lst[x] = T; T->to = y; (++T)->nxt = lst[y]; lst[y] = T; T->to = x; } inline void initLCA(int x, int Fa) { dfn[x] = ++tis; sze[x] = 1; dep[x] = dep[Fa] + 1; for (int i = 0; i < 18; ++i) fa[x][i + 1] = fa[fa[x][i]][i]; for (Edge *e = lst[x]; e; e = e->nxt) { int y = e->to; if (y == Fa) continue; fa[y][0] = x; initLCA(y, x); sze[x] += sze[y]; } } inline int queryLCA(int x, int y) { if (dep[x] < dep[y]) swap(x, y); for (int i = 18; i >= 0; --i) { if (dep[fa[x][i]] >= dep[y]) x = fa[x][i]; if (x == y) return x; } for (int i = 18; i >= 0; --i) if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i]; return fa[x][0]; } inline int Jump(int x, int d) { for (int i = 0; d; ++i, d >>= 1) if (d & 1) x = fa[x][i]; return x; } inline bool cmp(const int &x, const int &y) {return dfn[x] < dfn[y];} inline void auxTree() { top = 0; sort(vir + 1, vir + m + 1, cmp); for (int i = 1; i <= im; ++i) { int x = vir[i]; if (!top) { stk[++top] = x; par[x] = 0; continue; } int lca = queryLCA(x, stk[top]); for (; top && dep[stk[top]] > dep[lca]; --top) if (dep[stk[top - 1]] <= dep[lca]) par[stk[top]] = lca; if (lca != stk[top]) { if (dep[stk[top]] == dep[lca]) --top; par[lca] = stk[top]; stk[++top] = vir[++m] = lca; tr[lca] = point(Maxn, 0); } par[x] = lca; stk[++top] = x; } sort(vir + 1, vir + m + 1, cmp); } int main() { n = get(); for (int i = 1; i < n; ++i) Link(get(), get()); initLCA(1, 0); q = get(); while (q--) { im = m = get(); int x; for (int i = 1; i <= m; ++i) { vir[i] = virs[i] = x = get(); tr[x] = point(0, x); Ans[x] = 0; } auxTree(); for (int i = m; i >= 2; --i) { int x = vir[i], y = par[x]; del[x] = dep[x] - dep[y]; point tmp = point(del[x] + d(x), t(x)); if (tmp < tr[y]) tr[y] = tmp; } for (int i = 2; i <= m; ++i) { int x = vir[i], y = par[x]; point tmp = point(del[x] + d(y), t(y)); if (tmp < tr[x]) tr[x] = tmp; } for (int i = 1; i <= m; ++i) { int x = vir[i], y = par[x]; Ans[t(x)] += sze[x]; if (i == 1) { Ans[t(x)] += n - sze[x]; continue; } int w = Jump(x, del[x] - 1), totS = sze[w] - sze[x]; Ans[t(y)] -= sze[w]; if (t(x) == t(y)) Ans[t(x)] += totS; else { int z = d(x) - d(y) + dep[x] + dep[y] + 1 >> 1; if (t(y) < t(x) && d(y) + z - dep[y] == d(x) - z + dep[x]) ++z; z = sze[Jump(x, dep[x] - z)] - sze[x]; Ans[t(x)] += z; Ans[t(y)] += totS - z; } } for (int i = 1; i <= im; ++i) put(Ans[virs[i]]), outChar(' '); outChar('\n'); } fwrite(fwt, 1, ohed - fwt, stdout); return 0; }