CodeForces 1324F Maximum White Subtree
题目大意
给定一棵树,每个节点为白色或黑色,现对每个节点 u u u,选出任意大小的包含该节点的连通图,使得图上白色点数目与黑色的数目的差距最大。
分析
不妨记白点权值为1,黑点权值为-1,则问题转化为选择包含 u u u 的一个连通子图,使得其点权和最大。
显然我们可以注意到这是一个典型的树形DP。对于一个节点 u u u ,不妨将它的答案分成两部分来考虑:一部分为 u u u 的子树,另一部分为去除 u u u 的子树后但包含 u u u 的剩下一部分。
先考虑 u u u 的子树的答案,这是一个非常简单的树形DP,记 f ( u ) f(u) f(u) 为 u u u 的子树的最大点权和,注意到若子节点答案为负则可直接舍去该子节点,则容易得出状态转移方程:
f ( u ) = v a l ( u ) + ∑ max ( f ( v ) , 0 ) f(u) = val(u) + \sum \max(f(v), 0) f(u)=val(u)+∑max(f(v),0)
其中 v a l ( u ) val(u) val(u) 为 u u u 自身点权,而 v v v 为 u u u 的子节点。
然后再考虑如何计算第二部分(记为 g ( u ) g(u) g(u))的答案。
首先可以看出 g ( u ) g(u) g(u) 必然由 u u u 的父亲来计算,也就是说, g ( u ) g(u) g(u) 的答案要不就只包含 u u u,要不就包含 u u u 的父亲和它的兄弟节点。所以我们可以得出这样的转移方程:
g ( u ) = v a l ( u ) + max ( 0 , g ( f a ) + ∑ v ≠ u max ( f ( v ) , 0 ) ) g(u) = val(u) + \max(0, g(fa) + \sum^{v \ne u} \max(f(v), 0)) g(u)=val(u)+max(0,g(fa)+∑v=umax(f(v),0))
其中 f a fa fa 为 u u u 的父节点, v v v 为 f a fa fa 的子节点。
考虑到直接计算这个式子的 Σ \Sigma Σ 会使得复杂度到 O ( n 2 ) O(n^2) O(n2),我们必须考虑优化。注意到上面计算 f ( u ) f(u) f(u) 的式子,我们将它改成计算 f ( f a ) f(fa) f(fa) 的式子:
f ( f a ) = v a l ( f a ) + max ( f ( u ) , 0 ) + ∑ v ≠ u max ( f ( v ) , 0 ) f(fa) = val(fa) + \max(f(u), 0) + \sum^{v \ne u} \max(f(v), 0) f(fa)=val(fa)+max(f(u),0)+∑v=umax(f(v),0)
也即 ∑ v ≠ u max ( f ( v ) , 0 ) = f ( f a ) − v a l ( f a ) − max ( f ( u ) , 0 ) \sum^{v \ne u} \max(f(v), 0) = f(fa) - val(fa) - \max(f(u), 0) ∑v=umax(f(v),0)=f(fa)−val(fa)−max(f(u),0)
因而第二个状态转移方程可化为
g ( u ) = v a l ( u ) + max ( 0 , g ( f a ) + f ( f a ) − v a l ( f a ) − max ( f ( u ) , 0 ) ) g(u) = val(u) + \max(0, g(fa) + f(fa) - val(fa) - \max(f(u), 0)) g(u)=val(u)+max(0,g(fa)+f(fa)−val(fa)−max(f(u),0))
故可 O ( n ) O(n) O(n) 计算出 g ( u ) g(u) g(u)。而在合并答案时由于 v a l ( u ) val(u) val(u) 被计算了两次,合并需减去一次。
参考代码
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int MaxN = (int)2e5;
int N, val[MaxN + 5];
vector<int> G[MaxN + 5];
inline void addedge(int u, int v) {
G[u].push_back(v), G[v].push_back(u);
}
int f[MaxN + 5], g[MaxN + 5];
void DFS1(int u, int fa) {
f[u] = val[u];
for(auto v : G[u]) {
if(v == fa) continue;
DFS1(v, u);
f[u] += max(0, f[v]);
}
}
void DFS2(int u, int fa) {
int t = f[fa] - val[fa] - max(0, f[u]);
g[u] = val[u] + max(0, g[fa] + t);
for(auto v : G[u]) {
if(v == fa) continue;
DFS2(v, u);
}
}
int main() {
scanf("%d", &N);
for(int i = 1; i <= N; i++) {
int col;
scanf("%d", &col);
val[i] = (col ? 1 : -1);
}
for(int i = 1; i < N; i++) {
int u, v;
scanf("%d %d", &u, &v);
addedge(u, v);
}
DFS1(1, -1);
DFS2(1, 0);
for(int i = 1; i <= N; i++) {
int x = f[i] + max(0, g[i] - val[i]);
printf("%d ", x);
}
return 0;
}