【题目链接】
【思路要点】
- 首先考虑一种 O ( N ) O(N) O(N) 处理一个询问的方法,我们以 R R R 为根,进行 d f s dfs dfs 。
- 我们称一个存在关键叶子的子树为“满的”,当且仅当该子树中叶子结点的数量等于关键叶子的数量,称一个子树为“不满的”,当且仅当该子树中存在关键叶子,并且它不是满的。
- 若一个点 x x x 存在三个或以上不满的子树,那么显然我们不可能将其安排至合法,因此,我们称点 x x x 是不合法的。
- 若一个点 x x x 子树内不存在不合法的点,并且 x x x 存在恰好两个不满的子树,那么我们必须将这两个子树安排在开头和结尾,即将其中一个的关键叶子安排为最后访问的一系列叶子,将另一个的关键叶子安排为最先访问的一系列叶子。如此一来,我们不可能将点 x x x 子树内的关键叶子安排为最先或是最后访问,因此我们要求所有的关键叶子都必须在点 x x x 的子树中出现,否则,我们同样认为 x x x 是不合法的。
- 若一个点 x x x 子树内不存在不合法的点,并且 x x x 存在不足两个不满的子树,我们一定可以将点 x x x 子树内的关键叶子安排为最先或是最后访问,因此这样的点 x x x 始终合法。
- 答案为 N O NO NO 当且仅当存在不合法的点。
- 上述计算过程可以通过构建虚树进行优化,代码细节不再赘述。
- 时间复杂度 O ( N L o g N + ∑ K i L o g N ) O(NLogN+\sum K_iLogN) O(NLogN+∑KiLogN) 。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 5e5 + 5; const int MAXLOG = 21; typedef long long ll; typedef long double ld; typedef unsigned long long ull; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } bool ans, mark[MAXN]; int timer, size[MAXN], depth[MAXN], dfn[MAXN], rit[MAXN]; int root, m, n, t, all, father[MAXN][MAXLOG], sum[MAXN], cnt[MAXN]; vector <int> a[MAXN], b[MAXN]; int lca(int x, int y) { if (depth[x] < depth[y]) swap(x, y); for (int i = MAXLOG - 1; i >= 0; i--) if (depth[father[x][i]] >= depth[y]) x = father[x][i]; if (x == y) return x; for (int i = MAXLOG - 1; i >= 0; i--) if (father[x][i] != father[y][i]) { x = father[x][i]; y = father[y][i]; } return father[x][0]; } int get(int x, int y) { int z = lca(x, y); if (x != z) return father[x][0]; for (int i = MAXLOG - 1; i >= 0; i--) if (depth[father[y][i]] > depth[x]) y = father[y][i]; return y; } int calcsize(int x) { if (dfn[root] >= dfn[x] && dfn[root] <= rit[x]) return sum[1] - sum[get(x, root)]; else return sum[x]; } void work(int pos, int fa) { cnt[pos] = mark[pos]; int notfull = 0; for (auto x : b[pos]) if (x != fa) { work(x, pos); int y = get(pos, x); notfull += cnt[x] && (calcsize(y) != cnt[x]); cnt[pos] += cnt[x]; } if (notfull >= 3) ans = false; else if (notfull == 2) ans &= cnt[pos] == all; } void dfs(int pos, int fa) { size[pos] = 1; sum[pos] = a[pos].size() == 1; dfn[pos] = ++timer; depth[pos] = depth[fa] + 1; father[pos][0] = fa; for (int i = 1; i < MAXLOG; i++) father[pos][i] = father[father[pos][i - 1]][i - 1]; for (unsigned i = 0; i < a[pos].size(); i++) if (a[pos][i] != fa) { dfs(a[pos][i], pos); sum[pos] += sum[a[pos][i]]; size[pos] += size[a[pos][i]]; } rit[pos] = timer; } bool cmp(int x, int y) { return dfn[x] < dfn[y]; } int main() { read(n), read(t); for (int i = 1; i <= n - 1; i++) { int x, y; read(x), read(y); a[x].push_back(y); a[y].push_back(x); } dfs(1, 0); while (t--) { read(root), read(m); static int q[MAXN]; for (int i = 1; i <= m; i++) { read(q[i]); mark[q[i]] = true; } q[++m] = root; sort(q + 1, q + m + 1, cmp); int cnt = 0, top = 0, s = 1; static int Stack[MAXN], used[MAXN]; if (q[1] == 1) s = 2; Stack[1] = 1; top = 1; used[1] = 1; cnt = 1; for (int i = s; i <= m; i++) { int Lca = lca(Stack[top], q[i]); if (Lca == Stack[top]) { Stack[++top] = q[i]; continue; } while (dfn[Lca] < dfn[Stack[top - 1]]) { b[Stack[top - 1]].push_back(Stack[top]); b[Stack[top]].push_back(Stack[top - 1]); used[++cnt] = Stack[top--]; } if (Lca == Stack[top - 1]) { b[Stack[top - 1]].push_back(Stack[top]); b[Stack[top]].push_back(Stack[top - 1]); used[++cnt] = Stack[top--]; Stack[++top] = q[i]; } else { b[Lca].push_back(Stack[top]); b[Stack[top]].push_back(Lca); used[++cnt] = Stack[top--]; Stack[++top] = Lca; Stack[++top] = q[i]; } } while (top >= 2) { b[Stack[top - 1]].push_back(Stack[top]); b[Stack[top]].push_back(Stack[top - 1]); used[++cnt] = Stack[top--]; } ans = true, all = m - 1; work(root, 0); if (ans) puts("YES"); else puts("NO"); for (int i = 1; i <= m; i++) mark[q[i]] = false; for (int i = 1; i <= cnt; i++) b[used[i]].clear(); } return 0; }