有一棵 n 个节点的树,树上每个节点都有一个正整数权值。如果一个点被选择了,那么在树上和它相邻的点都不能被选择。求选出的点的权值和最大是多少?
第一行包含一个整数 n 。
接下来的一行包含 n 个正整数,第 i 个正整数代表点 i 的权值。
接下来一共 n-1 行,每行描述树上的一条边。
1 2 3 4 5
1 2
1 3
2 4
2 5
对于20%的数据, n <= 20。
对于50%的数据, n <= 1000。
对于100%的数据, n <= 100000。
权值均为不超过1000的正整数。
树形DP来解决
DP[i][j] 节点i j=0为选取,j=1为不选
DP[i][0] += max(DP[q][0],DP[q][1])
DP[i][1] += DP[q][0];
我们倒着来,从子节点开始,最后输出max(dp[1][0],dp[1][1])
建个树开始DFS即可
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include<algorithm>
#include<iostream>
#include<math.h>
#include<queue>
using namespace std;
struct node
{
int v;
int next;
}edge[200020];
int head[100010];
int dp[100010][2];
int M;
void addEdge(int from,int to)
{
edge[M].v=to;
edge[M].next=head[from];
head[from]=M++;
edge[M].v=from;
edge[M].next=head[to];
head[to]=M++;
return ;
}
void dfs(int x,int pre)
{
for(int i=head[x];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(pre==v)
{
continue;
}
dfs(v,x);
dp[x][1]+=dp[v][0];
dp[x][0]+=max(dp[v][0],dp[v][1]);
}
}
int main()
{
int n;
cin>>n;
memset(head,-1,sizeof(head));
memset(dp,0,sizeof(dp));
for(int i=1;i<=n;i++)
{
cin>>dp[i][1];
}
for(int i=1;i<n;i++)
{
int a,b;
cin>>a>>b;
addEdge(a,b);
}
dfs(1,-1);
cout<<max(dp[1][1],dp[1][0])<<endl;
}
/*
string str;
string str2;
int str3[1100];
int main()
{
int n;
cin>>n;
while(n--)
{
cin>>str;
int len=str.length();
for(int i=len-1;i>=0;i--)
{
if(str[i]=='0')
{
str2+="0000";
break;
}
if(str[i]=='1')
{
str2+="0001";
break;
}
if(str[i]=='2')
{
str2+="0010";
break;
}
if(str[i]=='3')
{
str2+="0011";
break;
}
if(str[i]=='4')
{
str2+="0100";
break;
}
if(str[i]=='5')
{
str2+="0101";
break;
}
if(str[i]=='6')
{
str2+="0110";
break;
}
if(str[i]=='7')
{
str2+="0111";
break;
}
if(str[i]=='8')
{
str2+="1000";
break;
}
if(str[i]=='9')
{
str2+="1001";
break;
}
if(str[i]=='A')
{
str2+="1010";
break;
}
if(str[i]=='B')
{
str2+="1011";
break;
}
if(str[i]=='C')
{
str2+="1100";
break;
}
if(str[i]=='D')
{
str2+="1101";
break;
}
if(str[i]=='E')
{
str2+="1110";
break;
}
if(str[i]=='F')
{
str2+="1111";
break;
}
}
int shuchu=0;
int sum;
int t;
len=str2.length();
for(int i=0;i<len;i+=3)
{
sum=0;
t=2;
for(int j=i;j<i+3;j++)
{
int q=str2[j]-'0';
sum=sum+q*pow(2,t);
t--;
}
str3[shuchu++]=sum;
}
int i;
for( i=shuchu-1;i>=0;i--)
{
if(str3[i]!=0)
break;
}
for(int j=i;j>=0;j--)
{
cout<<str3[j];
}
cout<<endl;
}
}*/