先简单说一下启发式合并吧
这道题我们可以遍历整棵树,并用一个数组ap(appear)记录每种颜色出现几次
但是每做完一棵子树就需要清空ap,以免对其兄弟造成影响。
而这样做它的祖先时就要把它重新搜一遍,浪费时间
但是我们发现,对于每个节点v,最后一棵子树是不用清空的,因为做完那棵子树后可 以把其结果直接加入v的答案中。
选哪棵子树呢?当然是所含节点最多的一棵咯,我们称之为“重儿子”
其实感觉这样快不了多少……但是它竟然是nlogn的!
觉得这里对于启发式合并的思路讲得好
做一做模板题:洛谷 Lomsat gelral
整体思路就是用cnt数组记录子树中颜色出现的次数,但是计算完一棵子树需要清空,最后计算重儿子的子树不需要清空
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 100005;
int sz[N], big[N], col[N], cnt[N];//子树大小,重儿子,颜色,颜色出现的次数
ll ans[N];
vector<int> g[N];
ll ma, sum;
void dfs0(int x, int fa)//预处理,算出重儿子
{
sz[x] = 1, big[x] = 0;
for (auto y : g[x])
{
if (y == fa) continue;
dfs0(y, x);
if (sz[y] > sz[big[x]]) big[x] = y;
sz[x] += sz[y];
}
}
void change(int x, int fa, int v, int nt)
//v = 1递归计算整颗子树(除了重儿子)的影响并更新ma和sum
//v = -1消除子树对于cnt数组的影响
//nt是不能走的点,即最初调用时的x的重儿子,不用的时候置0
{
cnt[col[x]] += v;
if (cnt[col[x]] > ma) ma = cnt[col[x]], sum = col[x];
else if (cnt[col[x]] == ma) sum += col[x];
for (auto y: g[x])
{
if (y == fa || y == nt) continue;
change(y, x, v, nt);
}
}
void dfs(int x, int fa, bool keep)
{
for (auto y : g[x])//计算轻儿子的ans,并消除对cnt的影响
{
if (y == fa || y == big[x]) continue;
dfs(y, x, false);
}
if (big[x]) dfs(big[x], x, true);//计算重儿子的ans,保留对cnt的影响
//计算x的ans,不走重儿子,因为走过了并且保留了cnt
change(x, fa, 1, big[x]);
ans[x] = sum;
if (keep == false) //需要消除影响
{
change(x, fa, -1, 0);//整颗子树都要消除,注意nt不能填big[x](错了两次了呜
ma = sum = 0;//change完要置0,因为都要消除影响了,之后有需要重新计算的地方
}
}
int main()
{
int n;
cin >> n;
for (int i = 1; i <= n; i++)
{
cin >> col[i];
}
for (int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
g[x].push_back(y);
g[y].push_back(x);
}
dfs0(1, 0);
// for (int i = 1; i <= n; i++)
// cout << big[i] << ' ';
ma = sum = 0;
dfs(1, 0, true);
for (int i = 1; i <= n; i++)
cout << ans[i] << ' ';
return 0;
}
来到学启发式合并的初衷[蓝桥杯 2023 省 A] 颜色平衡树
就是加了一个桶来记录cnt的状况,以此来维护子树上的最大值和最小值,一棵子树最大值=最小值即是平衡的
(这种最值维护方法也是学到了,思路来自P9233
直接在上题的代码上改
#include <vector>
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;
const int N = 200005;
int col[N], cnt[N], sz[N], big[N], t[N];//t是记录cnt[col]的桶,方便算最大值最小值
vector<int> g[N];
void dfs0(int x, int fa)
{
sz[x] = 1;
big[x] = 0;
for (auto y: g[x])
{
if (y == fa) continue;
dfs0(y, x);
sz[x] += sz[y];
if (sz[y] > sz[big[x]]) big[x] = y;
}
}
int ma, mi, res = 0;
void init()
{
ma = 0, mi = N;
}
void change(int x, int fa, int v, int nt)
{
//更新桶
t[cnt[col[x]]]--;
cnt[col[x]] += v;
t[cnt[col[x]]]++;
//更新最大值最小值
if (cnt[col[x]] > ma) ma = cnt[col[x]];
if (cnt[col[x]] < mi) mi = cnt[col[x]];
if (t[ma] == 0) ma--;//cnt减的时候最大值可能退下来
if (t[mi] == 0) mi++;//cnt加的时候最小值可能涨
for (auto y : g[x])
{
if (y == fa || y == nt) continue;
change(y, x, v, nt);
}
}
void dfs(int x, int fa, bool keep)
{
for (auto y : g[x])
{
if (y == fa || y == big[x]) continue;
dfs(y, x, false);
}
if (big[x]) dfs(big[x], x, true);
change(x, fa, 1, big[x]);
if (ma == mi) res++;
if (keep == false)
{
change(x, fa, -1, 0);
init();
}
}
int main()
{
int n;
cin >> n;
for (int i = 1; i <= n; i++)
{
int y;
cin >> col[i] >> y;
g[i].push_back(y);
g[y].push_back(i);
}
dfs0(1, 0);
init();
dfs(1, 0, false);
cout << res << endl;
return 0;
}
先这样