题意:给出一棵树,每个节点有一个正整数值,可能有1个或者没有炸弹,要求你找出一条单向路,在碰到炸弹数量不超过C的情况下取得的值尽可能大,起点任意,需要注意的是,碰到C个炸弹之后,立刻结束,哪怕还有没有炸弹的节点都不能走了,因为一个小BUG,多校结束之后20分钟才A了,太可惜了,希望错误不要再犯!
思路:典型树形DP,大致做法差不多,定义DP[ I ][ J ][ K ],以I节点为跟的树走过J个炸弹以K为方向的最优解,方向有入跟出两种,处理下当炸弹数目等于C时候的情况就可以了。
#include <cstdio>
#include <cstring>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <map>
#include <set>
#include <sstream>
#include <iostream>
#include <algorithm>
#include<cstdlib>
using namespace std;
#define N 100100
#define L(x) x<<1
#define R(x) x<<1|1
#define M(x,y) (x + y)>>1
#define MOD 1000000007
#define MODD 1000000006
#define inf 0x7fffffff
#define llinf 0x7fffffffffffffff
#define LL __int64
struct edge
{
LL u,v,next;
}s[N*2];
struct st
{
LL num;
LL c;
}a[N];
LL cnt,visit[N];
LL mark[N];
LL dp[N][4][2];
LL marks[4],ans,n,m;
void addedge(LL a,LL b)
{
s[cnt].u = a;
s[cnt].v = b;
s[cnt].next = visit[a];
visit[a] = cnt++;
s[cnt].u = b;
s[cnt].v = a;
s[cnt].next = visit[b];
visit[b] = cnt++;
}
void dfs(LL x)
{
mark[x] = 1;
dp[x][a[x].c][0] = dp[x][a[x].c][1] = a[x].num;
for(LL i = visit[x];i != -1;i = s[i].next)
{
if(mark[s[i].v] == 1)
continue;
dfs(s[i].v);
for(LL j = 0;j <= m;j++)
{
for(LL k = 0;k + j <= m;k++)
{
if(j != m)
{
// cout<<"###"<<endl;
// cout<<j<<' '<<dp[x][j][0]<<' '<<dp[s[i].v][k][1]<<endl;
ans = max(ans,dp[x][j][0] + dp[s[i].v][k][1]);
}
if(k != m)
{
// cout<<"@@@"<<endl;
// cout<<k<<' '<<dp[x][j][1]<<' '<<dp[s[i].v][k][0]<<endl;;
ans = max(ans,dp[x][j][1] + dp[s[i].v][k][0]);
}
if(j + k + 1 <= m)
{
ans = max(ans,dp[x][j][0] + dp[s[i].v][k][0]);
}
}
}
for(LL j = 0;j + a[x].c <= m;j++)
{
dp[x][j + a[x].c][0] = max(dp[x][j + a[x].c][0],dp[s[i].v][j][0] + a[x].num);
}
for(LL j = 1;j + a[x].c <= m;j++)
{
dp[x][j + a[x].c][1] = max(dp[x][j + a[x].c][1],dp[s[i].v][j][1] + a[x].num);
}
}
/* cout<<"show "<<x<<endl;
for(LL i = 0;i <= m;i++)
{
cout<<dp[x][i][0]<<' '<<dp[x][i][1]<<endl;
}
cout<<endl<<endl;*/
}
int main()
{
LL i,j,k;
LL t;
scanf("%I64d",&t);
while(t--)
{
scanf("%I64d%I64d",&n,&m);
cnt = 0;
memset(visit,-1,sizeof(visit));
memset(mark,0,sizeof(mark));
memset(dp,0,sizeof(dp));
for(i = 0;i < n;i++)
{
scanf("%I64d%I64d",&a[i].num,&a[i].c);
}
for(i = 1;i < n;i++)
{
scanf("%I64d%I64d",&j,&k);
addedge(j,k);
}
ans = 0;
dfs(0);
printf("%I64d\n",ans);
}
return 0;
}