3257: 树的难题
Time Limit: 10 Sec Memory Limit: 128 MB
Submit: 56 Solved: 39
[Submit][Status][Discuss]
Description
给出一个无根树。树有N个点,边有权值。每个点都有颜色,是黑色、白色、
灰色这三种颜色之一,称为一棵三色树。
可爱的 Alice觉得,一个三色树为均衡的,当且仅当,树中不含有黑色结点
或者含有至多一个白色节点。然而,给出的三色树可能并不满足这个性质。
所以,Alice打算删去若干条边使得形成的森林中每棵树都是均衡的,花费
的代价等于删去的边的权值之和。请你计算需要花费的代价最小是多少。
注意,输入文件包含多组测试数据。
Input
第一行包含一个正整数 T,表示有 T组测试数据。接下来依次是 T组测试数
据。每组测试数据的第一行包含一个正整数 N。
第二行包含 N个 0、1、2之一的整数,依次表示点 1到点 N的颜色。其中,
0 表示黑色,1表示白色,2表示灰色。
接下来 N-1行,每行为三个整数ui、vi、c i,表示一条权值等于ci的边(ui, vi)。
树形DP
当至多取一个的时候,我们需要计算出来不取的答案,然后减掉取min,即至多取一个。
#include
#include
#include
#include
#define maxn 500010
using namespace std;
typedef long long ll;
ll dp[maxn][3];
int n;
struct Edge{
int to, next, dis;
}edge[maxn * 2];
int h[maxn], cnt;
void add(int u, int v, int d){
cnt ++;
edge[cnt].to = v;
edge[cnt].next = h[u];
edge[cnt].dis = d;
h[u] = cnt;
}
const ll inf = 1ll << 50;
int fa[maxn], c[maxn];
ll d[maxn];
int Min(int a, int b, int c){
if(b < a)a = b;
if(c < a)a = c;
return a;
}
/*
dp[u][2]//一个白点。
dp[u][1]//无白点
dp[u][0]//无黑点
*/
void dfs(int u){
for(int i = h[u]; i; i = edge[i].next){
int v = edge[i].to;
if(v == fa[u])continue;
fa[v] = u;
d[v] = edge[i].dis;
dfs(v);
}
if(c[u] == 0){//此点为黑。至多一个白点。
ll ret = 0, q = 0;
for(int i = h[u]; i; i = edge[i].next){
int v = edge[i].to;
if(v == fa[u])continue;
ll t = min(dp[v][0], dp[v][2]) + d[v];
ret += min(dp[v][1], min(dp[v][0], dp[v][2]) + d[v]);
}
q = ret;
for(int i = h[u]; i; i = edge[i].next){
int v = edge[i].to;
if(v == fa[u])continue;
q = min(q, ret - min(dp[v][1], min(dp[v][0], dp[v][2]) + d[v]) + dp[v][2]);
}
dp[u][2] = q;
dp[u][1] = ret;
dp[u][0] = inf;
}
if(c[u] == 1){//此点为白。不含黑点/子树不含白点。
ll ret1 = 0, ret2 = 0;
for(int i = h[u]; i; i = edge[i].next){
int v = edge[i].to;
if(v == fa[u])continue;
ret1 += min(dp[v][0], min(dp[v][1], dp[v][2]) + d[v]);
ret2 += min(dp[v][1], min(dp[v][2], dp[v][0]) + d[v]);
}
dp[u][2] = ret2;
dp[u][1] = inf;
dp[u][0] = ret1;
}
if(c[u] == 2){//此点为灰。不含黑点/至多一个白点。
ll ret1 = 0, ret = 0, q = 0;
for(int i = h[u]; i; i = edge[i].next){
int v = edge[i].to;
if(v == fa[u])continue;
ret1 += min(dp[v][0], min(dp[v][1], dp[v][2]) + d[v]);
ret += min(dp[v][1], min(dp[v][0], dp[v][2]) + d[v]);
}
q = ret;
for(int i = h[u]; i; i = edge[i].next){
int v = edge[i].to;
if(v == fa[u])continue;
q = min(q, ret - min(dp[v][1], min(dp[v][0], dp[v][2]) + d[v]) + dp[v][2]);
}
dp[u][2] = q;
dp[u][1] = ret;
dp[u][0] = ret1;
}
}
int main(){
int test;
scanf("%d", &test);
while(test --){
scanf("%d", &n);
for(int i = 1; i <= n; i ++)
scanf("%d", &c[i]);
cnt = 0;
memset(h, 0, sizeof h);
for(int i = 1; i <= n; i ++)
dp[i][0] = dp[i][1] = dp[i][2] = inf;
int u, v, dis;
for(int i = 1; i < n; i ++){
scanf("%d%d%d", &u, &v, &dis);
add(u, v, dis);
add(v, u, dis);
}
d[1] = 0;dfs(1);
ll ans = min(dp[1][0], dp[1][1]);
ans = min(ans, dp[1][2]);
printf("%lld\n", ans);
}
return 0;
}