捋一捋
analysis
- 考虑遍历每一个节点,以每个节点作为 l c a lca lca 思考。
- 当前节点为 l c a lca lca 那么要想答案更大肯定是从不同子树(不同子树满足 l c a lca lca)中各选择一个节点到 l c a lca lca 不同颜色最多,假设 c i ci ci 为每个节点到当前节点不同颜色的数量,那么就要选择每个子树中最大的 c i ci ci ,然后选出最大值和次大值。
- 考虑使用 d f s dfs dfs ,然后在回溯的过程中更新节点和答案。
problems
- 如何找出子树的最大 c i ci ci ?
- 如何更新一个节点的贡献?
- 在回溯的过程中遇见颜色一致的节点怎样做到不重不漏?
solutions
- 对于最大 c i ci ci 很容易想到 R M Q RMQ RMQ 的做法,在这里可以采用 d f s dfs dfs 序结合一个数据结构。
- 如何更新一个节点的贡献,因为 c i ci ci 是向上到每一个节点的不同颜色的数量,所以当当前节点更新答案过后就根据我们 d f s dfs dfs 的出来的区间进行区间加 1 1 1 。(子节点一定会经过这个节点)
- 对于回溯,因为是回溯,所以很容易想到更新是一个自下而上的过程。所以我们应该考虑的是祖先节点与孙子节点(当然也可能是父节点与儿子节点)颜色一致的情况。上面对于节点的更新遇见孙子节点一致的话,我们不妨用一个 s e t set set 装入每个颜色对应的节点的 d f s dfs dfs 序。孙子节点的 d f s dfs dfs 序一定大于祖先节点。所以在更新答案之前我们可以先将颜色一致的孙子节点进行区间减,然后在更新答案。
- 综上所述数据结构可食用 线段树
- 每个节点至多进行一次区间加,区间减和 s e t set set 的插入删除。时间复杂度 O ( n l o g n ) O(nlogn) O(nlogn) 。
Think Twice, Code once
#include <bits/stdc++.h>
#define il inline
#define get getchar
#define put putchar
#define is isdigit
#define int long long
#define dfor(i, a, b) for(int i = a; i <= b; ++i)
#define dforr(i, a, b) for(int i = a; i >= b; --i)
#define dforn(i, a, b) for(int i = a; i <= b; ++i, put(10))
#define mem(a, b, c) memset(a, b, c)
#define memc(a, b) memcpy(a, b, sizeof (a))
#define pr 114514191981
#define gg(a) cout << a, put(32)
#define INF 0x7fffffff
#define tf(x) cout << '\n' << "-> " << x << " <-" << '\n';
#define endl '\n'
#define ls i << 1
#define rs i << 1 | 1
#define la(r) tr[r].ch[0]
#define ra(r) tr[r].ch[1]
#define lowbit(x) (x & -x)
#define ct cin.tie(nullptr),ios_base::sync_with_stdio(false)
using namespace std;
typedef unsigned int ull;
typedef pair<int, int> pii;
int read(void) {
int x = 0, f = 1; char c = get();
while(!is(c)) (f = c == 45? -1: 1), c = get();
while(is(c)) x = (x << 1) + (x << 3) + (c ^ 48), c = get();
return x * f;
}
void write(int x) {
if (x < 0) x = -x, put(45);
if (x > 9) write(x / 10);
put((x % 10) ^ 48);
}
#define writeln(a) write(a), put(10)
#define writesp(a) write(a), put(32)
#define writessp(a) put(32), write(a)
const int N = 3e5 + 10, M = 2e5 + 10, SN = 1e3 + 10, mod = 1e9 + 9, MOD = 998244353;
int tot, ans, a[N], ed[N], re[N], dfn[N], head[N];
vector<pii> e(N);
set<int, greater<int>> s[N];
struct p {
int l, r, Max, tag;
}tr[N << 2];
void build(int i, int l, int r) {
tr[i] = {l, r, 0, 0};
if (l == r) return ;
int mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
}
void pushup(int i) {
tr[i].Max = max(tr[ls].Max, tr[rs].Max);
}
void pushdown(int i) {
if (tr[i].tag) {
tr[ls].Max += tr[i].tag;
tr[rs].Max += tr[i].tag;
tr[ls].tag += tr[i].tag, tr[rs].tag += tr[i].tag;
tr[i].tag = 0;
}
}
void modify(int i, int l, int r, int v) {
if (l <= tr[i].l && tr[i].r <= r) {
tr[i].Max += v;
tr[i].tag += v;
return ;
}
pushdown(i);
if (l <= tr[ls].r) modify(ls, l, r, v);
if (r >= tr[rs].l) modify(rs, l, r, v);
pushup(i);
}
int query(int i, int l, int r) {
if (l <= tr[i].l && tr[i].r <= r) return tr[i].Max;
pushdown(i);
int res = 0;
if (l <= tr[ls].r) res = max(res, query(ls, l, r));
if (r >= tr[rs].l) res = max(res, query(rs, l, r));
return res;
}
void dfs1(int u) {
dfn[u] = ++tot, re[tot] = u;
for (int i = head[u]; i; i = e[i].second) {
int v = e[i].first;
dfs1(v);
}
ed[u] = tot;
}
void dfs2(int u) {
int Max1 = 0, Max2 = 0;
for (int i = head[u]; i; i = e[i].second) {
int v = e[i].first;
dfs2(v);
while (!s[a[u]].empty() && *s[a[u]].begin() > dfn[u]) {
int l = *s[a[u]].begin(), r = ed[re[*s[a[u]].begin()]];
modify(1, l, r, -1);
s[a[u]].erase(s[a[u]].begin());
}
int t = query(1, dfn[v], ed[v]);
if (t > Max1) Max2 = Max1, Max1 = t;
else if (t > Max2) Max2 = t;
}
// while (!s[a[u]].empty() && *s[a[u]].begin() > dfn[u]) {
// int l = *s[a[u]].begin(), r = ed[re[*s[a[u]].begin()]];
// modify(1, l, r, -1);
// }
// for (int i = head[u]; i; i = e[i].second) {
// int v = e[i].first;
// int t = query(1, dfn[v], ed[v]);
// if (t > Max1) Max2 = Max1, Max1 = t;
// else if (t > Max2) Max2 = t;
// }
// cout << "Max1: " << Max1 << " Max2: " << Max2 << endl;
ans = max((Max1 + 1) * (Max2 + 1), ans);
modify(1, dfn[u], ed[u], 1);
s[a[u]].insert(dfn[u]);
}
signed main() {
int cnt = 0;
auto add = [] (int u, int v, int &cnt) {
e[++cnt] = {v, head[u]}, head[u] = cnt;
};
auto init = [] (int n ,int &cnt) {
ans = 1, cnt = tot = 0;
memset(tr + 1, 0, sizeof(p) * 4 * n);
for (int i = 1; i <= n; ++i) s[i].clear(), head[i] = 0;
};
int T = 1;
T = read();
while (T--) {
int n = read();
init(n, cnt);
for (int i = 2; i <= n; ++i) {
int pi = read();
add(pi, i, cnt);
}
for (int i = 1; i <= n; ++i) a[i] = read();
build(1, 1, n);
dfs1(1);
dfs2(1);
writeln(ans);
}
return 0;
}
//12
//1 1 1 2 2 3 4 4 7 7 6
//11 2 1 11 12 8 5 8 8 5 11 7