题目大意:
给定一棵有 n 个点的树,第 i 个点有 di 件商品,价格为 ci,价值为 wi。
你手头有 m 块钱,且你要保证你买过的点在树上互相连通,问买到的物品的总价值最多是多少。
1 ≤ n ≤ 500, 1 ≤ m ≤ 4000, di ≤ 100。
解题思路:
如果直接树形dp是 O(nm2d) O ( n m 2 d ) 的,显然过不了
考虑如果强制要选一个点怎么做。
相当于以该点为根,那么就是一个有依赖的多重背包(就是选了子树根才能选子树中的点),直接在树上做也是
O(nm2)
O
(
n
m
2
)
,这里有一个套路:先搞出dfs序,设f[i]表示考虑了dfs序后i个点,花了j的钱的答案,如果不选就跳过整棵子树的区间,如果选就从i+1转移过来即可。这样就转到了序列上,可以用单调队列优化到
O(nm)
O
(
n
m
)
,详见这里 。
如果我们用点分治,每次把重心设为必选点,总复杂度就是 O(nmlogn) O ( n m l o g n ) 的,而且可以保证正确性。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
int getint()
{
int i=0,f=1;char c;
for(c=getchar();(c!='-')&&(c<'0'||c>'9');c=getchar());
if(c=='-')c=getchar(),f=-1;
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=505,M=4005;
int n,m,ans,w[N],c[N],d[N];
int root,totsize,idx,size[N],maxsub[N],id[N],vis[N],q[N],f[N][M];
int tot,first[N],nxt[N<<1],to[N<<1];
void add(int x,int y)
{
nxt[++tot]=first[x],first[x]=tot,to[tot]=y;
}
void getroot(int u,int fa)
{
size[u]=1,maxsub[u]=0;
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];if(v==fa||vis[v])continue;
getroot(v,u),size[u]+=size[v];
if(size[v]>maxsub[u])maxsub[u]=size[v];
}
maxsub[u]=max(maxsub[u],totsize-size[u]);
if(maxsub[u]<maxsub[root])root=u;
}
void dfs(int u,int fa)
{
id[++idx]=u,size[u]=1;
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];if(v==fa||vis[v])continue;
dfs(v,u),size[u]+=size[v];
}
}
void solve(int u)
{
vis[u]=1,idx=0,dfs(u,0);
memset(f[idx+1],0,sizeof(f[idx+1]));
for(int i=idx;i;i--)
{
int u=id[i],p=c[u],h=w[u],cnt=d[u],head=1,tail=0;
for(int j=0;j<=m;j++)f[i][j]=f[i+size[u]][j];
for(int j=0;j<p;j++)
{
int head=1,tail=0;
for(int k=0;j+k*p<=m;k++)
{
while(head<=tail&&q[head]<k-cnt)head++;
if(head<=tail)f[i][j+k*p]=max(f[i][j+k*p],f[i+1][j+q[head]*p]+(k-q[head])*h);
while(head<=tail&&f[i+1][j+k*p]-k*h>=f[i+1][j+q[tail]*p]-q[tail]*h)tail--;
q[++tail]=k;
}
}
}
for(int i=1;i<=m;i++)ans=max(ans,f[1][i]);
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];if(vis[v])continue;
maxsub[root=0]=totsize=size[v],getroot(v,u),solve(root);
}
}
int main()
{
freopen("lx.in","r",stdin);
freopen("lx.out","w",stdout);
for(int T=getint();T;T--)
{
ans=tot=0;
memset(first,0,sizeof(first));
memset(vis,0,sizeof(vis));
n=getint(),m=getint();
for(int i=1;i<=n;i++)w[i]=getint();
for(int i=1;i<=n;i++)c[i]=getint();
for(int i=1;i<=n;i++)d[i]=getint();
for(int i=1;i<n;i++)
{
int x=getint(),y=getint();
add(x,y),add(y,x);
}
totsize=maxsub[root=0]=n,getroot(1,0),solve(root);
printf("%d\n",ans);
}
return 0;
}