首先肯定是01分数规划,先二分答案,设当前答案是
x
x
,那么每个点的权值就是,我们要在树上取一个包含节点0的大小为
k+1
k
+
1
的连通块(因为有0所以+1),使得权值最大,然后看这个最大权值是否大于等于0
这个可以考虑树型背包
这个看上去是
O(n3)
O
(
n
3
)
的,但有一个小技巧可以让他变成
O(n2)
O
(
n
2
)
我们考虑一个和传统树型dp不太一样的状态
设
dp[i][j]
d
p
[
i
]
[
j
]
表示当前考虑到原树的dfs序的第i个节点,还打算在后面的节点中选j个点的最大收益
- 要在i为根的子树内选择一部分节点,那么至少得选择i这个节点, dp[i][j]=max(dp[i][j],dp[i+1][j−1]+ai−xbi) d p [ i ] [ j ] = m a x ( d p [ i ] [ j ] , d p [ i + 1 ] [ j − 1 ] + a i − x b i )
- 不在i为根的子树内选择节点,那么直接跳过以i为根的子树,我们知道在dfs序中以i为根的子树是连续的一段, dp[i][j]=max(dp[i][j],dp[i+size[i]][j]) d p [ i ] [ j ] = m a x ( d p [ i ] [ j ] , d p [ i + s i z e [ i ] ] [ j ] )
特别的,如果
i+size[i]>n+1
i
+
s
i
z
e
[
i
]
>
n
+
1
且
j≠0
j
≠
0
,那么必须在以i为根的子树内选择节点,没有第二种转移
总复杂度
O(n2logn)
O
(
n
2
l
o
g
n
)
最好写从后向前的递推版本,记忆化搜索有点慢
#include <cassert>
#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <utility>
#include <cctype>
#include <algorithm>
#include <bitset>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <cmath>
#define LL long long
#define LB long double
#define x first
#define y second
#define Pair pair<int,LL>
#define pb push_back
#define pf push_front
#define mp make_pair
#define LOWBIT(x) x & (-x)
using namespace std;
const int MOD=100003;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;
const double pi=acos(-1);
inline int getint()
{
char ch;int res;bool f;
while (!isdigit(ch=getchar()) && ch!='-') {}
if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
while (isdigit(ch=getchar())) res=res*10+ch-'0';
return f?res:-res;
}
int n,k;
vector<int> v[2548];
double a[2548],b[2548];
int seq[2548],ind=0;
double dp[2548][2548];
int sz[2548];
double l,r,mid;
inline void dfs(int cur)
{
int i,y;sz[cur]=1;seq[++ind]=cur;
for (i=0;i<int(v[cur].size());i++)
{
y=v[cur][i];dfs(y);
sz[cur]+=sz[y];
}
}
inline bool check()
{
int i,j;
for (i=1;i<=n+2;i++)
for (j=0;j<=k;j++)
dp[i][j]=-INF;
for (i=1;i<=n+2;i++) dp[i][0]=0;
for (i=n+1;i>=1;i--)
for (j=1;j<=min(k,n+2-i);j++)
if (i+sz[seq[i]]<=n+1)
dp[i][j]=max(dp[i+1][j-1]+a[seq[i]]-b[seq[i]]*mid,dp[i+sz[seq[i]]][j]);
else
dp[i][j]=dp[i+1][j-1]+a[seq[i]]-b[seq[i]]*mid;
return dp[1][k]>=0;
}
int main ()
{
int i,x;
k=getint();n=getint();k++;
for (i=1;i<=n;i++)
{
b[i]=getint();a[i]=getint();
x=getint();v[x].pb(i);
}
dfs(0);
l=0;r=1e4;double ans=0;
while (r-l>1e-5)
{
mid=(l+r)/2;
if (check()) ans=mid,l=mid; else r=mid;
}
printf("%.3lf\n",ans);
return 0;
}