题目：给定一棵带权树，从中选出若干节点，要求任意一对父子节点不能同时被选，求所选节点权值之和的最大值（树形 DP）。
代码1（记忆化搜索：dfs(u, i) 返回状态值）
#include <bits/stdc++.h>
using namespace std;
// Capacities: N nodes, M edges (a tree on n nodes has n-1 edges, so M = N is enough).
const int N = 6010;
const int M = N;
// Sentinel meaning "memo entry not computed yet"; matches memset(f, 0x3f, sizeof f).
const int null = 0x3f3f3f3f;

int n;          // NOTE(review): appears unused — main declares its own local n
int a[N];       // a[u]: weight of node u
int h[N];       // h[u]: index of the first edge leaving u, or -1 if u has no children
int e[M];       // e[k]: target node of edge k
int ne[M];      // ne[k]: next edge after k in the same list, or -1
int idx;        // next free edge slot
int f[N][2];    // f[u][0]: best total in u's subtree with u skipped; f[u][1]: with u chosen
bool st[N];     // st[v] is set once v is seen as a child (used by main to find the root)

// Memoized search over the tree.
// Returns the best total weight in u's subtree, where i == 1 means u is chosen
// (so none of its children may be chosen) and i == 0 means u is skipped.
// Precondition: f is filled with the `null` sentinel before the first call.
// A leaf's value is returned without being cached by the leaf itself; the
// caller stores it into the child's f entry instead.
int dfs(int u, int i)
{
    if (h[u] == -1)                 // leaf: choosing it yields its weight, skipping yields 0
        return i ? a[u] : 0;

    if (f[u][i] != null)            // this (node, state) pair is already solved
        return f[u][i];

    int total = 0;
    for (int edge = h[u]; edge != -1; edge = ne[edge])
    {
        const int child = e[edge];
        const int skipChild = f[child][0] = dfs(child, 0);
        if (i)
        {
            // u is chosen, so every child must be skipped.
            total += skipChild;
        }
        else
        {
            // u is skipped, so each child independently takes its better option.
            const int takeChild = f[child][1] = dfs(child, 1);
            total += std::max(skipChild, takeChild);
        }
    }
    if (i)
        total += a[u];

    f[u][i] = total;
    return total;
}

// Append the directed edge u -> v (parent -> child) by head insertion.
void add(int u, int v)
{
    e[idx] = v;
    ne[idx] = h[u];
    h[u] = idx;
    idx++;
}
int main()
{
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> a[i];

    memset(h, -1, sizeof h);
    // Each of the n-1 lines is "v u": add(u, v) makes u the parent of v.
    for (int i = 0; i < n - 1; i++)
    {
        int v, u;
        cin >> v >> u;
        add(u, v);
        st[v] = true;   // v has a parent, so it cannot be the root
    }

    // The root is the only node never marked as a child. Initialized so that a
    // malformed input with no such node does not read an uninitialized variable.
    int root = 1;
    for (int i = 1; i <= n; i++)
    {
        if (!st[i])
        {
            root = i;
            break;
        }
    }

    memset(f, 0x3f, sizeof f);  // fill f with the `null` sentinel that dfs expects
    int res = max(dfs(root, 0), dfs(root, 1));
    cout << res << '\n';
    return 0;
}
代码2（自底向上：一次递归同时求出 f[u][0] 和 f[u][1]）
#include <bits/stdc++.h>
using namespace std;
// Capacities: N nodes, M edges (a tree on n nodes has n-1 edges, so M = N is enough).
const int N = 6010, M = N;
const int null = 0x3f3f3f3f;  // NOTE(review): unused in this version (no memo sentinel needed)
int n;                        // NOTE(review): appears unused — main declares its own local n
int a[N];                     // a[u]: weight of node u
int h[N], e[M], ne[M], idx;   // adjacency list: head / edge target / next edge / next free slot
int f[N][2];                  // f[u][0]: best total with u skipped; f[u][1]: with u chosen
bool st[N];                   // st[v] is set once v is seen as a child (used to find the root)

// Fills f[u][0] and f[u][1] for u and every node in its subtree, children first.
// Both states of u are reset before accumulating. dfs therefore does not depend
// on f starting at zero, and calling it again gives the same result instead of
// adding onto stale sums.
void dfs(int u)
{
    f[u][0] = 0;
    f[u][1] = a[u];
    for (int k = h[u]; k != -1; k = ne[k])
    {
        int j = e[k];
        dfs(j);                                   // child's states must be ready first
        f[u][0] += std::max(f[j][0], f[j][1]);    // u skipped: child takes its better option
        f[u][1] += f[j][0];                       // u chosen: child must be skipped
    }
}

// Append the directed edge u -> v (parent -> child) by head insertion.
void add(int u, int v)
{
    e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}
int main()
{
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> a[i];

    memset(h, -1, sizeof h);
    // Each of the n-1 lines is "v u": add(u, v) makes u the parent of v.
    for (int i = 0; i < n - 1; i++)
    {
        int v, u;
        cin >> v >> u;
        add(u, v);
        st[v] = true;   // v has a parent, so it cannot be the root
    }

    // The root is the only node never marked as a child. Initialized so that a
    // malformed input with no such node does not read an uninitialized variable.
    int root = 1;
    for (int i = 1; i <= n; i++)
    {
        if (!st[i])
        {
            root = i;
            break;
        }
    }

    dfs(root);
    int res = max(f[root][0], f[root][1]);
    cout << res << '\n';
    return 0;
}
代码2更好：每次递归同时求出一个节点的两个状态 f[u][0] 和 f[u][1]，每个节点只递归一次，递归调用次数更少。它不需要用哨兵值（0x3f3f3f3f）做记忆化来避免重复计算。递归只负责保证子节点先于父节点计算，不再通过返回值传递结果，分工更明确。