题意
给出一个节点数为 N N N的无根树,每个点有白、黑、灰三个颜色。
去掉一些边使得剩下的分割的集合满足:
1
、
1、
1、不含有黑色节点
2
、
、
、含有至多一个白色节点
求出去掉的边的最小边权和。
思路
显然树形
d
p
dp
dp。
d
f
s
dfs
dfs会爆栈,用
b
f
s
bfs
bfs进行类似拓扑一样的顺序进行
d
p
dp
dp。
代码
#include<queue>
#include<cstdio>
#include<cstring>
#include<algorithm>
const int N = 300001;
const long long inf = 1e17;
int t, n, tot, root;
int c[N], deg[N];
int ver[2 * N], edge[2 * N], next[2 * N], head[N];
long long f[N], g[N], h[N];
std::queue<int> q;
void add(int u, int v, int w) {
ver[++tot] = v;
edge[tot] = w;
next[tot] = head[u];
head[u] = tot;
ver[++tot] = u;
edge[tot] = w;
next[tot] = head[v];
head[v] = tot;
}
long long bfs() {
while (q.size()) q.pop();
for (int i = 1; i <= n; i++)
if (deg[i] == 1) q.push(i);
while (q.size()) {
int x = q.front();
q.pop();
deg[x]--;
root = x;
if (c[x] == 0) {
f[x] = inf;
for (int i = head[x]; i; i = next[i])
if (!deg[ver[i]]) g[x] += std::min(std::min(f[ver[i]] + edge[i], g[ver[i]]), h[ver[i]] + edge[i]);
for (int i = head[x]; i; i = next[i]) {
if (!deg[ver[i]]) h[x] = std::min(h[x], h[ver[i]] + g[x] - std::min(std::min(f[ver[i]] + edge[i], g[ver[i]]), h[ver[i]] + edge[i]));
}
} else if (c[x] == 1) {
g[x] = inf;
for (int i = head[x]; i; i = next[i])
if (!deg[ver[i]]) {
f[x] += std::min(std::min(f[ver[i]], g[ver[i]] + edge[i]), h[ver[i]] + edge[i]);
h[x] += std::min(std::min(f[ver[i]] + edge[i], g[ver[i]]), h[ver[i]] + edge[i]);
}
} else if (c[x] == 2) {
for (int i = head[x]; i; i = next[i])
if (!deg[ver[i]]) {
f[x] += std::min(std::min(f[ver[i]], g[ver[i]] + edge[i]), h[ver[i]] + edge[i]);
g[x] += std::min(std::min(f[ver[i]] + edge[i], g[ver[i]]), h[ver[i]] + edge[i]);
}
for (int i = head[x]; i; i = next[i])
if (!deg[ver[i]]) h[x] = std::min(h[x], h[ver[i]] + g[x] - std::min(std::min(f[ver[i]] + edge[i], g[ver[i]]), h[ver[i]] + edge[i]));
}
for (int i = head[x]; i; i = next[i])
if (deg[ver[i]] > 1) {
deg[ver[i]]--;
if (deg[ver[i]] == 1) q.push(ver[i]);
}
}
return std::min(std::min(f[root], g[root]), h[root]);
}
int main() {
scanf("%d", &t);
for (; t; t--) {
memset(head, 0, sizeof(head));
memset(deg, 0, sizeof(deg));
tot = 0;
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &c[i]);
if (c[i] == 0) f[i] = inf, g[i] = 0, h[i] = inf;
else if (c[i] == 1) f[i] = 0, g[i] = inf, h[i] = 0;
else f[i] = 0, g[i] = 0, h[i] = inf;
}
for (int i = 1, x, y, z; i < n; i++)
scanf("%d %d %d", &x, &y, &z), add(x, y, z), deg[x]++, deg[y]++;
printf("%lld\n", bfs());
}
}