Description
考虑对括号序列进行哈希。
二分答案 k k k。考虑一个点的 k − k- k−子树就是它的子树去掉距离它大于 k k k的节点,对应到括号序列上就是一个大区间减去若干个小区间。把每个点挂在它的 k + 1 k+1 k+1级祖先上,然后计算即可。
Solution
考虑对括号序列进行哈希。
二分答案 k k k。考虑一个点的 k − k- k−子树就是它的子树去掉距离它大于 k k k的节点,对应到括号序列上就是一个大区间减去若干个小区间。把每个点挂在它的 k + 1 k+1 k+1级祖先上,然后计算即可。
#include <bits/stdc++.h>
using namespace std;
typedef long long lint;
const int mod[2] = {998244353, 1004535809}, base[2] = {5007449, 5669};
const int maxn = 100005;
int n, son[maxn];
struct edge
{
int to, next;
} e[maxn * 2];
int h[maxn], tot;
int dfn[maxn], low[maxn], s[maxn * 2], len[maxn], Time;
int pw[maxn * 2][2], hsh[maxn * 2][2];
int stk[maxn], top;
vector<int> vec[maxn];
unordered_map<int, int> g[2];
inline int gi()
{
char c = getchar();
while (c < '0' || c > '9') c = getchar();
int sum = 0;
while ('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
return sum;
}
inline void add(int u, int v)
{
e[++tot] = (edge) {v, h[u]};
h[u] = tot;
}
void dfs(int u)
{
s[dfn[u] = ++Time] = 1;
for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
dfs(v), len[u] = max(len[u], len[v]);
++len[u]; low[u] = ++Time;
}
void dfs(int u, int k)
{
stk[++top] = u;
if (top > k + 1) vec[stk[top - k - 1]].push_back(u);
for (int i = h[u], v; v = e[i].to, i; i = e[i].next)
dfs(v, k);
--top;
}
void add(int *val, int l, int r)
{
val[0] = (hsh[r][0] + (lint)(val[0] - hsh[l - 1][0] + mod[0]) * pw[r - l + 1][0]) % mod[0];
val[1] = (hsh[r][1] + (lint)(val[1] - hsh[l - 1][1] + mod[1]) * pw[r - l + 1][1]) % mod[1];
}
bool check(int k)
{
for (int i = 1; i <= n; ++i) vec[i].clear();
dfs(1, k);
g[0].clear(); g[1].clear();
for (int i = 1; i <= n; ++i) {
if (len[i] <= k) continue;
int val[2] = {0, 0}, l = dfn[i], r;
for (int v : vec[i]) {
r = dfn[v] - 1;
if (l <= r) add(val, l, r);
l = low[v] + 1;
}
add(val, l, low[i]);
if (g[0].count(val[0]) && g[1].count(val[1])) return 1;
g[0][val[0]] = 1; g[1][val[1]] = 1;
}
return 0;
}
int main()
{
n = gi();
for (int x, i = 1; i <= n; ++i) {
x = gi();
for (int j = 1; j <= x; ++j) son[j] = gi();
for (int j = x; j >= 1; --j) add(i, son[j]);
}
dfs(1);
pw[0][0] = pw[0][1] = 1;
for (int i = 1; i <= Time; ++i) {
for (int j = 0; j < 2; ++j) {
pw[i][j] = (lint)pw[i - 1][j] * base[j] % mod[j];
hsh[i][j] = ((lint)hsh[i - 1][j] * base[j] + (s[i] ? 19 : 23)) % mod[j];
}
}
int l = 0, r = len[1], mid;
while (l < r) {
mid = (l + r + 1) >> 1;
if (check(mid)) l = mid;
else r = mid - 1;
}
printf("%d\n", l);
return 0;
}