换根动态规划
树形 DP 中的换根 DP 问题又被称为二次扫描,通常不会指定根结点,并且根结点的变化会对一些值,例如子结点深度和、点权和等产生影响。
通常需要两次 DFS,第一次 DFS 预处理诸如深度,点权和之类的信息,在第二次 DFS 开始运行换根动态规划。
1. 不带权版本
- 此题是维护不同根节点的子树高度
我在此就通俗易懂的说说个人理解,理解不到位之处还望指出。
所谓换根DP,就是基于原有的状态,通过相邻节点进行转换后,现有的状态仅仅只需要进行微小的变动即可达到完美相邻状态间的切换。
而对应在这题,要计算不同根下的树高,首先需要计算出原始状态,这里用 0 0 0 号节点为根来计算。那么之后进行换根操作,理应从 0 0 0 号节点开始。
我们定义:
- u u u:当前节点
- j j j : u u u 的某一相邻节点
- h [ ] h[] h[]:维护当前某个节点为根下的子树高状态
- f [ ] f[] f[]:计算得到当前节点 i i i 为根下的 i i i 的树高
那么如果当前 u u u 为根节点,现在要进行换根操作到 j j j 节点为根。明确 j j j 是 u u u 的其中一个相邻节点,所以 j j j 换为根后, u u u 的子树高状态应当扣除掉之前维护中包含的 j j j 子树高的状态。而怎么样才会影响到 u u u 的状态的改变呢?无非是两种情况: j j j 树高是 u u u 树高的最大值,除去后, u u u 树高应变为次大值 + 1 +1 +1;反之, u u u 树高仍是相邻节点的最大值树高 + 1 +1 +1。
那么考虑完了两个根节点之间状态的转变,根相邻节点的状态应该如何变化?在这仅以 u u u 节点的相邻节点考虑,对称状态为 j j j 的相邻节点。由于我们维护的状态是在 h h h 数组中, h [ i ] h[i] h[i] 仅表示当前某个节点为根下, i i i 的子树高。所以当根由 u u u 向 j j j 转变后,实际上对 u u u 的除 j j j 的相邻节点,这些点的树高状态并不会改变,读者可以自己思考。
上述就实现完成了换根后当前 h h h 数组维护的状态的转变,如下代码所示:
for (auto j: e[u]) {
if (h[j] == mx1) h[u] = mx2 + 1; // j点是最大的树高,u的树高是次高+1
else h[u] = mx1 + 1; // j点不是最大树高,u的树高是最高+1
dfs2(dfs2, j);
}
如何更新最后想要求得的 f f f 数组呢?
f [ i ] f[i] f[i]:表示 i i i 为根下, i i i 的子树高。
那么在 d f s 2 dfs2 dfs2 ,第二次DFS中,执行换根 u − j u-j u−j 前,当前根 u u u 的 f [ u ] f[u] f[u] 直接更新即可:
f[u] = mx1 + 1;
此题全部代码如下:
class Solution {
public:
vector<int> findMinHeightTrees(int n, vector<vector<int>>& edges) {
vector<int> e[n + 1]; // 构建图的邻接表
for (auto edge: edges) {
int x = edge[0], y = edge[1];
e[x].push_back(y);
e[y].push_back(x);
} // 建图
int h[n + 1]; memset(h, 0, sizeof h); // 0为根下的原树高
auto dfs1 = [&](auto &&dfs1, int u, int f) -> int {
int s = 0;
for (auto j: e[u]) {
if (j != f) s = max(s, dfs1(dfs1, j, u));
}
h[u] = s + 1;
return h[u];
};
dfs1(dfs1, 0, 0); // dfs1跑出h数组
int f[n + 1]; memset(f, 0, sizeof f); // 维护以i为根的树高f[i]
auto dfs2 = [&](auto &&dfs2, int u) -> void {
int mx1 = 0, mx2 = 0; // 最大、次大
for (auto j: e[u]) { // 跑出u的邻接点的mx1,mx2
if (mx1 < h[j]) mx2 = mx1, mx1 = h[j];
else if (mx2 < h[j]) mx2 = h[j];
}
f[u] = mx1 + 1; // 当前u为根的树高 = 最大子树高+1
for (auto j: e[u]) { // 换根操作!-> 维护h数组的变动
if (f[j] != 0) continue; // 已经算过就不在计算了
if (h[j] == mx1) h[u] = mx2 + 1; // j点是最大的树高,u的树高是次高+1
else h[u] = mx1 + 1; // j点不是最大树高,u的树高是最高+1
dfs2(dfs2, j);
}
};
dfs2(dfs2, 0); // 跑出f数组
int mx = n;
vector<int> res;
for (int i = 0; i < n; i++) {
if (mx > f[i]) mx = f[i], res.clear();
if (mx == f[i]) res.push_back(i);
}
return res;
}
};
2. 带权版本
- 此题与上题完全一致,套用换根模板即可。
class Solution {
public:
long long maxOutput(int n, vector<vector<int>>& edges, vector<int>& price) {
vector<int> e[n + 1];
for (auto edge: edges) {
int x = edge[0], y = edge[1];
e[x].push_back(y);
e[y].push_back(x);
}
long long hx[n + 1], hn[n + 1]; memset(hx, 0, sizeof hx); memset(hn, 0, sizeof hn);
auto dfs1 = [&](auto &&dfs1, int u, int f) -> void {
long long mx = 0, mn = 0;
for (auto j: e[u]) {
if (j == f) continue;
dfs1(dfs1, j, u);
mx = max(mx, hx[j]);
mn = min(mn, hn[j]);
}
hx[u] = mx + price[u];
hn[u] = mn + price[u];
};
dfs1(dfs1, 0, -1);
long long fx[n + 1], fn[n + 1]; memset(fx, 0, sizeof fx); memset(fn, 0, sizeof fn);
auto dfs2 = [&](auto &&dfs2, int u) -> void {
long long mx1 = 0, mx2 = 0;
long long mn1 = 0, mn2 = 0;
for (auto j: e[u]) {
if (mx1 < hx[j]) mx2 = mx1, mx1 = hx[j];
else if (mx2 < hx[j]) mx2 = hx[j];
if (mn1 > hn[j]) mn2 = mn1, mn1 = hn[j];
else if (mn2 > hn[j]) mn2 = hn[j];
}
fx[u] = mx1 + price[u];
fn[u] = mn1 + price[u];
for (auto j: e[u]) {
if (fx[j] != 0 && fn[j] != 0) continue;
hx[u] = hx[j] == mx1 ? mx2 : mx1; hx[u] += price[u];
hn[u] = hn[j] == mn1 ? mn2 : mn1; hn[u] += price[u];
dfs2(dfs2, j);
}
};
dfs2(dfs2, 0);
long long res = 0;
for (int i = 0; i < n; i++) res = max(res, fx[i] - /*fn[i]*/ price[i]);
return res;
}
};
例题:统计可能的树根数目
class Solution {
public:
int rootCount(vector<vector<int>>& edges, vector<vector<int>>& guesses, int k) {
int n = edges.size();
vector<int> e[n + 1];
for (auto &edge: edges) {
e[edge[0]].push_back(edge[1]);
e[edge[1]].push_back(edge[0]);
}
map<pair<int, int>, int> mp;
for (auto &x: guesses) {
mp[{x[0], x[1]}] = 1;
}
int res = 0;
function<void(int, int)> dfs1 = [&](int u, int f) -> void {
for (auto &j: e[u]) {
if (j == f) continue;
if (mp.count({u, j})) res++;
dfs1(j, u);
}
};
dfs1(0, -1);
int ans = 0;
function<void(int, int)> dfs2 = [&](int u, int f) -> void {
if (res >= k) ans++;
for (auto &j: e[u]) {
if (j == f) continue;
int tmp = 0;
if (mp.count({u, j})) tmp--;
if (mp.count({j, u})) tmp++;
res += tmp;
dfs2(j, u);
res -= tmp;
}
};
dfs2(0, -1);
return ans;
}
};