这道题和昨晚在CF上做的一道题很像,都是树形dp。刚开始时,我写了个很烂的代码水过了,时间复杂度为O(n^3).不过一看觉得不对了,人家的都是0ms的,于是去网上找题解优化,时间就降为O(n^2)了,终究还是自己太水了……
优化前,300ms:
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
using namespace std;
template<class T> T Max(T x,T y){return x>y?x:y;}
template<class T> T Min(T x,T y){return x<y?x:y;}
#define N 105
vector<int> v[N];
int m,a[N],dp[N][N];
void dfs(int x,int y)
{
int i,j,k,z,len;
dp[x][1]=a[x];
len=v[x].size();
for(i=0;i<len;i++)
{
z=v[x][i];
if(y==z)continue;
dfs(z,x);
for(j=m;j>0;j--)
for(k=1;k<=j;k++)
dp[x][j]=Max(dp[x][j],dp[x][k]+dp[z][j-k]); //状态转移
}
}
int main()
{
//freopen("a.txt","r",stdin);
int i,n,x,y,ans;
while(scanf("%d%d",&n,&m)!=EOF)
{
for(i=0;i<n;i++)
{
scanf("%d",a+i);
v[i].clear();
}
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
v[x].push_back(y);
v[y].push_back(x);
}
ans=0;
for(i=0;i<n;i++) //这个循环浪费了时间
{
memset(dp,0,sizeof(dp));
dfs(i,-1);
ans=Max(ans,dp[i][m]);
}
printf("%d\n",ans);
}
return 0;
}
优化后,0ms:
#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
using namespace std;
template<class T> T Max(T x,T y){return x>y?x:y;}
template<class T> T Min(T x,T y){return x<y?x:y;}
#define N 105
vector<int> v[N];
int m,a[N],dp[N][N];
void dfs(int x,int y)
{
int i,j,k,z,len;
dp[x][1]=a[x];
len=v[x].size();
for(i=0;i<len;i++)
{
z=v[x][i];
if(y==z)continue;
dfs(z,x);
for(j=m;j>0;j--)
for(k=j;k>0;k--) //这里是优化最关键的地方,由大到小就可以防止覆盖
dp[x][j]=Max(dp[x][j],dp[x][k]+dp[z][j-k]);
}
}
int main()
{
//freopen("a.txt","r",stdin);
int i,n,x,y,ans;
while(scanf("%d%d",&n,&m)!=EOF)
{
for(i=0;i<n;i++)
{
scanf("%d",a+i);
v[i].clear();
}
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
v[x].push_back(y);
v[y].push_back(x);
}
memset(dp,0,sizeof(dp));
dfs(0,-1);
ans=0;
for(i=0;i<n;i++)ans=Max(ans,dp[i][m]);
printf("%d\n",ans);
}
return 0;
}