树状DP,这题我写了2个Dfs来做,第一个以任意一点(我用的0节点)作为根节点,求出以每个点作为根节点的子树上机会为1~C时各自的最大值,然后第二次Dfs用父亲节点来更新儿子节点,求出每个点作为根节点能获得的最大值,做了之后看了看别人的题解,貌似只用一个Dfs就够了,而且代码量也比较少,等会儿再去看看吧。
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 50005;
struct node
{
int val , fg;
int maxval[5];//保存机会为1~C时的最大值
int maxpos[5];//保存取得最大值的路径
int smaxval[5];//保存机会为1~C时的次最大值
int smaxpos[5];//保存取得次最大值的路径
}tp[N];
struct Node
{
int ev , next;
}tree[N*2];
int n , c , tot , father[N];
void Addedge(int s , int e)
{
tree[tot].ev = e;
tree[tot].next = father[s];
father[s] = tot++;
}
void Dfs(int s , int fa)
{
int i , j , vs = 0;
memset(tp[s].maxval,0,sizeof tp[s].maxval);
memset(tp[s].smaxval,0,sizeof tp[s].smaxval);
for (i = father[s] ; i != -1 ; i = tree[i].next)
{
int v = tree[i].ev;
if (v == fa)continue;
vs = 1;
Dfs(v,s);
int cnt = tp[s].fg;
for (j = 1 ; j <= c ; j++)
{
if (tp[s].maxval[j] < tp[v].maxval[j-cnt]+tp[s].val)
{
tp[s].smaxval[j] = tp[s].maxval[j];
tp[s].smaxpos[j] = tp[s].maxpos[j];
tp[s].maxval[j] = tp[v].maxval[j-cnt]+tp[s].val;
tp[s].maxpos[j] = v;
}
else if (tp[s].smaxval[j] < tp[v].maxval[j-cnt]+tp[s].val)
{
tp[s].smaxval[j] = tp[v].maxval[j-cnt]+tp[s].val;
tp[s].smaxpos[j] = v;
}
}
}
if (vs)return;
for (j = 1 ; j <= c ; j++)
tp[s].maxval[j] = tp[s].val;
}
void Dfs1(int s , int fa)
{
int i , j;
for (i = father[s] ; i != -1 ; i = tree[i].next)
{
int v = tree[i].ev;
if (v == fa)continue;
int cnt = tp[s].fg;
for (j = 1 ; j <= c ; j++)
{
if (tp[v].maxval[j] < tp[s].maxval[j-cnt]+tp[v].val && tp[s].maxpos[j-cnt] != v)
{
tp[v].smaxval[j] = tp[v].maxval[j];
tp[v].smaxpos[j] = tp[v].smaxpos[j];
tp[v].maxval[j] = tp[s].maxval[j-cnt]+tp[v].val;
tp[v].maxpos[j] = s;
}
else if (tp[v].maxval[j] < tp[s].smaxval[j-cnt]+tp[v].val && tp[s].smaxpos[j-cnt] != v)
{
tp[v].smaxval[j] = tp[v].maxval[j];
tp[v].smaxpos[j] = tp[v].smaxpos[j];
tp[v].maxval[j] = tp[s].smaxval[j-cnt]+tp[v].val;
tp[v].maxpos[j] = s;
}
else if (tp[v].smaxval[j] < tp[s].maxval[j-cnt]+tp[v].val && tp[s].maxpos[j-cnt] != v)
{
tp[v].smaxval[j] = tp[s].maxval[j-cnt]+tp[v].val;
tp[v].smaxpos[j] = s;
}
else if (tp[v].smaxval[j] < tp[s].smaxval[j-cnt]+tp[v].val && tp[s].smaxpos[j-cnt] != v)
{
tp[v].smaxval[j] = tp[s].smaxval[j-cnt]+tp[v].val;
tp[v].smaxpos[j] = s;
}
}
Dfs1(v,s);
}
}
int main()
{
int t , i;
scanf("%d",&t);
while (t--)
{
memset(father , -1 , sizeof father);
tot = 0;
scanf("%d%d",&n,&c);
for (i = 0 ; i < n ; i++)scanf("%d%d",&tp[i].val,&tp[i].fg);
for (i = 1 ; i < n ; i++)
{
int s , e;
scanf("%d%d",&s,&e);
Addedge(s,e);
Addedge(e,s);
}
Dfs(0,-1);
Dfs1(0,-1);
int MAX = 0;
for (i = 0 ; i < n ; i++)MAX = max(MAX , tp[i].maxval[c]);
printf("%d\n",MAX);
}
return 0;
}