Problem Description
给出一棵 n n n 个节点的有边权的无根树,每个点为白色或黑色,每个点都能翻转颜色,但是需要花费一定价格。
我们定义一棵树收益为: ∑ x ∈ V 1 ∑ y ∈ V 2 v a l ( x , y ) \sum\limits_{x \in V_1} \sum\limits_{y \in V2} val(x,y) x∈V1∑y∈V2∑val(x,y) , V 1 V_1 V1 表示白点集合, V 2 V_2 V2 表示黑点集合, v a l ( x , y ) val(x,y) val(x,y) 表示 x x x 到 y y y 最短路径上的最大边权。
求最大收益。
Input
第一行输入一个整数 n n n ,表示节点个数。
第二行输入 n n n 个整数 a i a_i ai ( 0 ≤ a i ≤ 1 ) (0 \le a_i \le 1) (0≤ai≤1) ,表示第 i i i 个节点的颜色, 0 0 0 表示白色, 1 1 1 表示黑色。
第三行输入 n n n 个整数 c o s t i cost_i costi ( 0 ≤ c o s t i ≤ 1 0 9 ) (0 \le cost_i \le 10^9) (0≤costi≤109) ,表示第 i i i 个节点翻转所需的费用。
接下来 n − 1 n-1 n−1 行每行输入三个整数 u i , v i , w i u_i,v_i,w_i ui,vi,wi ( 1 ≤ u i , v i , w i ≤ n ) (1 \le u_i,v_i,w_i \le n) (1≤ui,vi,wi≤n) ,表示 u i u_i ui 和 v i v_i vi 之间有一条边权为 w i w_i wi 的边。
Output
输出最大收益。
Solution
观察到 v a l ( x , y ) val(x,y) val(x,y) 表示 x x x 到 y y y 最短路径上的最大边权,我们很显然地就能联想到 K r u s k a l Kruskal Kruskal 重构树的构造和性质,两个点的路径最大边权即是重构树上两点的 L C A LCA LCA ,且构造重构树时是将两个集合合并,合并时的对应边权是两个集合路径上的最大边权。
这样我们可以将按边权为关键值排序后,从小到大枚举每条边的对答案的贡献。
那么问题来了如何求这个贡献?
此时我们便考虑
d
p
dp
dp ,设计
d
p
u
,
i
dp_{u,i}
dpu,i 表示当前集合
u
u
u 中白点个数为
i
i
i 的最大值。当两个集合合并时,我们先去枚举总的白点个数,然后再枚举其中一个集合的白点个数,最终我们可以得到转移方程为:
d
p
u
,
i
=
m
a
x
(
d
p
u
,
i
,
d
p
u
,
l
+
d
p
u
,
r
+
l
∗
(
m
−
r
)
+
(
n
−
l
)
∗
r
)
dp_{u,i}=max(dp_{u,i},dp_{u,l}+dp_{u,r}+l*(m-r)+(n-l)*r)
dpu,i=max(dpu,i,dpu,l+dpu,r+l∗(m−r)+(n−l)∗r)。
而对于翻转消费
c
o
s
t
i
cost_i
costi ,我们可以在初始化时候,将
d
p
i
,
a
[
i
]
=
−
c
o
s
t
i
dp_{i,a[i]}=-cost_i
dpi,a[i]=−costi ,
d
p
i
,
a
[
i
]
⊕
1
=
0
dp_{i,a[i] \oplus 1}=0
dpi,a[i]⊕1=0 即可,即表示成有
0
,
1
0,1
0,1 个白点的收益。
最后我们考虑复杂度,整体集合合并为 O ( n ) O(n) O(n), d p dp dp 转移的上界是 O ( n 2 ) O(n^2) O(n2) ,极端情况下复杂度会达到 O ( n 3 ) O(n^3) O(n3) ,这是我们无法接受的。
但是我们仔细考虑到 d p dp dp 转移的第二层是枚举其中一个集合的大小,此时我们考虑到启发式合并,每次枚举小的集合,这样最终的复杂度便是 O ( n 2 l o g n ) O(n^2logn) O(n2logn) ,能够通过此题。
Code
#include <bits/stdc++.h>
#define endl '\n'
using namespace std;
typedef long long ll;
constexpr int N = 3010;
int p[N];
int leader(int x) {
while (x != p[x])x = p[x] = p[p[x]];
return x;
}
vector<ll>dp[N];
void merge(int u, int v, int w) {
u = leader(u), v = leader(v);
if (dp[u].size() > dp[v].size())swap(u, v);
int n = dp[u].size() - 1, m = dp[v].size() - 1;
vector<ll>temp(n + m + 1, -2e18);
for (int i = 0; i <= n + m; i++) {//枚举总的白点个数
for (int j = max(0, i - m); j <= min(i, n); j++) {//枚举u中的白点个数
//max(0,i-m)是由于i-j<=m,min(i,n)是有由于i<=n
int l = j, r = i - j;
int paircnt = l * (m - r) + (n - l) * r;
temp[i] = max(temp[i], dp[u][l] + dp[v][r] + 1ll * paircnt * w);
}
}
dp[u] = temp;
p[v] = u;
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n;
cin >> n;
for (int i = 1 ; i <= n; i++) p[i] = i, dp[i].resize(2);
vector<int>a(n + 1), cost(n + 1);
for (int i = 1; i <= n; i++)cin >> a[i];
for (int i = 1; i <= n; i++)cin >> cost[i];
for (int i = 1; i <= n; i++)dp[i][a[i]] = -cost[i];
vector<tuple<int, int, int>>edge;
for (int i = 1; i <= n - 1; i++) {
int u, v, w;
cin >> u >> v >> w;
edge.emplace_back(w, u, v);
}
sort(edge.begin(), edge.end());
for (auto [w, u, v] : edge) merge(u, v, w);
int root = leader(1);
ll ans = 0;
for (auto x : dp[root])ans = max(ans, x);
cout << ans << endl;
}