传送门:HDU 5834
题意:给你一棵树,边有边权,每经过边一次,就得支付过路费c[i],点有点权,每个点只能获得一次。
问从每个点出发,能够获得的最大权值是多少?
思路:经典树形DP, dp[i][0]表示从i到以i为根的子树中去,再回到i能获得的最大权值。
dp[i][1]表示从i出发到以i为根的子树中不回到i点的最大权值。
f[i][0]表示从i出发去i的父节点方向再回到i点能获得的最大权值。
f[i][1]表示从i出发去i的父节点方向不回到i的最大权值。
dp[i][0/1]还是很好想的,难点是求f数组,因为涉及到去往哪个子树然后不回来的问题,因此求f数组的时候很考验思维的全面性,这题还是很值得好好研究的。
要注意的是不一定所有子树都要去一遍,因为去某个子树可能不能获得正权值,因此很多时候要取max(0,)
id[i]代表从i出发去往哪个子树不回来了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
using namespace std;
typedef pair<int,int> P;
const int MAXN = 100010;
struct node{
int v, w, next;
node(){}
node(int _v, int _w, int _next) : v(_v), w(_w), next(_next) {}
}mp[MAXN * 2];
int pre[MAXN], val[MAXN], dp[MAXN][2], f[MAXN][2], id[MAXN];
int cnt;
void add(int u, int v, int w)
{
mp[cnt] = node(v, w, pre[u]); pre[u] = cnt++;
mp[cnt] = node(u, w, pre[v]); pre[v] = cnt++;
}
void dfs1(int u, int fa)
{
dp[u][0] = dp[u][1] = val[u];
int v, w;
for(int i = pre[u]; ~i; i = mp[i].next)
{
v = mp[i].v; w = mp[i].w;
if(v == fa) continue;
dfs1(v, u);
dp[u][1] += max(dp[v][0] - 2 * w, 0);
if(dp[u][0] + dp[v][1] - w > dp[u][1])
{
dp[u][1] = dp[u][0] + dp[v][1] - w;
id[u] = v;
}
dp[u][0] += max(dp[v][0] - 2 * w, 0);
}
}
void dfs2(int u, int fa)
{
int v, w, tmp;
if(dp[u][1] + f[u][0] < f[u][1] + dp[u][0])
id[u] = fa;
for(int i = pre[u]; ~i; i = mp[i].next)
{
v = mp[i].v; w = mp[i].w;
if(v == fa) continue;
if(v == id[u])
{
int vv, ww;
int tmp0 = f[u][0] + val[u], tmp1 = f[u][1] + val[u];
for(int j = pre[u]; ~j; j = mp[j].next)
{
vv = mp[j].v; ww = mp[j].w;
if(vv == fa || vv == v) continue;
tmp1 = max(tmp1 + max(dp[vv][0] - 2 * ww, 0), tmp0 + max(dp[vv][1] - ww, 0));
tmp0 += max(dp[vv][0] - 2 * ww, 0);
}
f[v][1] = max(0, tmp1 - w);
f[v][0] = max(0, tmp0 - 2 * w);
}
else
{
if(dp[v][0] >= 2 * w)
tmp = dp[v][0] - 2 * w;
else
tmp = 0;
f[v][0] = max(0, f[u][0] + dp[u][0] - tmp - 2 * w);
f[v][1] = max(0, max(f[u][1] + dp[u][0] - tmp - w, f[u][0] + dp[u][1] - tmp - w));
}
dfs2(v, u);
}
}
int main()
{
int T, n, u, v, w;
int kase = 1;
cin >> T;
while(T--)
{
cnt = 0;
scanf("%d", &n);
for(int i = 1; i <= n; i++)
{
scanf("%d", val + i);
dp[i][0] = dp[i][1] = 0;
f[i][0] = f[i][1] = 0;
pre[i] = id[i] = -1;
}
for(int i = 1; i < n; i++)
{
scanf("%d %d %d", &u, &v, &w);
add(u, v, w);
}
dfs1(1, -1);
//for(int i = 1; i <= n; i++)
// printf("%d %d %d\n", dp[i][0], dp[i][1], id[i]);
dfs2(1, -1);
// for(int i = 1; i <= n; i++)
// printf("%d %d %d\n", f[i][0], f[i][1], id[i]);
printf("Case #%d:\n", kase++);
for(int i = 1; i <= n; i++)
printf("%d\n", max(dp[i][0] + f[i][1], dp[i][1] + f[i][0]));
}
return 0;
}