题目大意
给你一棵 n n n个点的树,每个点都有一个薪水 C i C_i Ci和领导力 L i L_i Li,你需要选择一个点 v v v,并在 v v v的子树中选若干个点,使得 ∑ C i ≤ m \sum C_i\leq m ∑Ci≤m,那么当前的满意度为 L v ∑ C i L_v\sum C_i Lv∑Ci。求最大的满意度。
1 ≤ n ≤ 1 0 5 , 1 ≤ m ≤ 1 0 9 , 1 ≤ C i ≤ m , 1 ≤ L i ≤ 1 0 9 1\leq n\leq 10^5,1\leq m\leq 10^9,1\leq C_i\leq m,1\leq L_i\leq 10^9 1≤n≤105,1≤m≤109,1≤Ci≤m,1≤Li≤109
题解
首先,对于每个点 v v v,求出最大的 ∑ C i \sum C_i ∑Ci。那么,我们可以先求这棵树的重链。用 v e c t o r vector vector来存每个点的 C i C_i Ci,先用重儿子的 v e c t o r vector vector,再在此基础上加 C i C_i Ci,最后,不断删去 v e c t o r vector vector中值最大的点,直到 ∑ C i ≤ m \sum C_i\leq m ∑Ci≤m。
因为每个点到根节点最多只会被 log n \log n logn个点加入,最多只会被 log n \log n logn次重链算入,每次插入的时间复杂度为 O ( log n ) O(\log n) O(logn),所以总时间复杂度为 O ( n log 2 n ) O(n\log^2n) O(nlog2n)。
设 m x v = ∑ C i mx_v=\sum C_i mxv=∑Ci,则求出最大的 L v × m x v L_v\times mx_v Lv×mxv即可。
时间复杂度为 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
这道题实际上用的是 s e t set set的思维,只不过用 s e t set set的话不方便,所以要用 v e c t o r vector vector模拟 s e t set set。
当然,这道题也可以用线段树合并,但没学过的话,用上面这个方法是挺不错的。
code
#include<bits/stdc++.h>
using namespace std;
int n,m,x,tot=0,c[100005],w[100005],d[100005],l[100005],r[100005],siz[100005],son[100005],mx[100005];
long long ans=0,v[100005];
vector<int>s[100005];
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void dfs1(int u,int fa){
siz[u]=1;
for(int i=r[u];i;i=l[i]){
if(d[i]==fa) continue;
dfs1(d[i],u);
siz[u]+=siz[d[i]];
if(siz[d[i]]>siz[son[u]]) son[u]=d[i];
}
}
void dfs2(int u,int fa){
if(son[u]){
dfs2(son[u],u);
swap(s[u],s[son[u]]);
v[u]=v[son[u]];
}
vector<int>::iterator wt=s[u].end();--wt;
if(s[u].empty()||c[u]>*wt) s[u].push_back(c[u]);
else{
int vt=upper_bound(s[u].begin(),s[u].end(),c[u])-s[u].begin();
s[u].insert(s[u].begin()+vt,c[u]);
}
v[u]+=c[u];
for(int i=r[u];i;i=l[i]){
if(d[i]==fa||d[i]==son[u]) continue;
dfs2(d[i],u);
if(s[d[i]].empty()) continue;
for(vector<int>::iterator it=s[d[i]].begin();it!=s[d[i]].end();++it){
wt=s[u].end();--wt;
if(s[u].empty()||*it>*wt) s[u].push_back(*it);
else{
int vt=upper_bound(s[u].begin(),s[u].end(),*it)-s[u].begin();
s[u].insert(s[u].begin()+vt,*it);
}
v[u]+=*it;
}
s[d[i]].clear();
}
for(vector<int>::iterator it=s[u].end();it!=s[u].begin();){
if(v[u]<=m) break;
--it;
v[u]-=*it;
s[u].erase(it);
}
mx[u]=s[u].size();
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d%d%d",&x,&c[i],&w[i]);
if(i>1) add(x,i);
}
dfs1(1,0);
dfs2(1,0);
for(int i=1;i<=n;i++){
ans=max(ans,1ll*mx[i]*w[i]);
}
printf("%lld",ans);
return 0;
}