#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
int head[6005],in[6005],a[6005],tot,dp[6005][2];
struct node{
int to,next;
}tree[6005];
void bt(int u,int v){
tree[++tot].to=v;
tree[tot].next=head[u];
head[u]=tot;
}
void dfs(int u){
dp[u][0]=0;
dp[u][1]=a[u];
for(int i=head[u];i;i=tree[u].next){
int v=tree[i].to;
dfs(v);
dp[u][0]+=max(dp[v][0],dp[v][1]);//dp[v][0]>dp[v][1]?dp[v][0]:dp[v][1];
dp[u][1]+=dp[v][0];
}
}
int main(){
int n,x,y,root;
cin>>n;
for(int i=1;i<=n;++i)
cin>>a[i];
for(int i=1;i<n;++i){
cin>>x>>y;
bt(y,x);
in[x]++;
}
cin>>x>>y;
for(int i=1;i<=n;++i)
if(in[i]==0){
root=i;
break;
}
dfs(root);
printf("%d",max(dp[root][0],dp[root][1]));
}
WA了4个点。求助!