思路:首先这是一道01分数规划,二分枚举答案
也就是 成立,l=mid,提高答案, 否则 r=mid, 减小答案
然后是树形dp,有两种方法:
第一种:直接dp, dp[i][j]为i节点必选,以i为根节点的子树选择j个的最大值
转移方程为 rt根节点+其他子树选j个, 该子树选c个
第二种:先求出dfs序, 再dp;
if(dp[i][j]+val[i]>dp[i+1][j+1]) 由该节点转移到孩子节点
if(dp[i][j]>dp[out[i]][j]) 由该节点转移到其他子树;
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e3+600;
const int mo=998244353;
const double eps=1e-4;
int n, k, si[N], pi[N], sz[N];
double val[N];
vector<int> G[N];
double dp[N][N];
int dcmp(double x){
if(fabs(x)<eps)
return 0;
else
return x<0?-1:1;
}
void dfs(int rt){
dp[rt][1]=val[rt];
sz[rt]=1;
for(int i=0; i<G[rt].size(); i++){
int to=G[rt][i];
dfs(to);
int up=min(k, sz[rt]);
for(int j=up; j>=1; j--){
for(int c=1; c<=sz[to]; c++){
if(j+c>k) break;
dp[rt][j+c]=max(dp[rt][j+c], dp[rt][j]+dp[to][c]);
}
}
sz[rt]+=sz[to];
}
}
int main(){
scanf("%d%d", &k, &n);
int f, mx=-1;
for(int i=1; i<=n; i++){
scanf("%d%d%d", &si[i], &pi[i], &f);
G[f].push_back(i);
mx=max(mx, pi[i]);
}
k++;
val[0]=0.0;
double l=0.0, r=mx*1.0;
while(l+eps<r){
for(int i=0; i<=n; i++)
for(int j=0; j<=k; j++)
dp[i][j]=-999999999;
double mid=(l+r)/2;
for(int i=1; i<=n; i++)
val[i]=1.0*pi[i]-mid*si[i];
dfs(0);
if(dp[0][k]>0)
l=mid;
else
r=mid;
}
printf("%.3lf\n", (l+r)/2);
return 0;
}
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e3+600;
const int mo=998244353;
const double eps=1e-4;
int n, k, si[N], pi[N], out[N], cnt, dfn[N];
double val[N];
vector<int> G[N];
double dp[N][N];
void dfs(int x){
dfn[x]=cnt++;
for(int i=0; i<G[x].size(); i++) dfs(G[x][i]);
out[dfn[x]]=cnt;
}
void work(double mid){
for(int i=0; i<=cnt; i++)
for(int j=0; j<=k; j++)
dp[i][j]=-9999999999;
for(int i=1; i<=cnt; i++)
val[dfn[i]]=pi[i]-mid*si[i];
dp[0][0]=0;
for(int i=0; i<=n; i++){
int up=min(i+1, k);
for(int j=0; j<=up; j++){
if(dp[i][j]+val[i]>dp[i+1][j+1])
dp[i+1][j+1]=dp[i][j]+val[i];
if(dp[i][j]>dp[out[i]][j])
dp[out[i]][j]=dp[i][j];
}
}
}
int main(){
scanf("%d%d", &k, &n);
int f, mx=-1;
for(int i=1; i<=n; i++){
scanf("%d%d%d", &si[i], &pi[i], &f);
G[f].push_back(i);
mx=max(mx, pi[i]);
}
dfs(0);
k++;
val[0]=0.0;
double l=0.0, r=mx*1.0;
while(l+eps<r){
double mid=(l+r)/2;
work(mid);
if(dp[cnt][k]>0)
l=mid;
else
r=mid;
}
printf("%.3lf\n", (l+r)/2);
return 0;
}