【题解】
离散化高度和花费,建立以价格为下标的前缀数量和与前缀花费和的树状数组,从高到低按高度处理,每次记录加上必定需要被下一个高度删去的当前高度的所有树木的花费,然后二分处理出需要砍掉的花费尽可能小的区间,最后算出花费更新答案即可。
【代码】
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
#define ll long long
ll num[maxn],sum[maxn];
int hh[maxn],hm;
int cc[maxn],cm;
ll hnum[maxn];
int lowbit(int x)
{
return x&(-x);
}
void add(int u,int num0,int val0)
{
for(;u<=cm;u+=lowbit(u)){
num[u]+=num0;
sum[u]+=1ll*num0*val0;
}
}
ll suma(int u)
{
ll res=0;
for(;u>=1;u-=lowbit(u))
res+=num[u];
return res;
}
ll sumb(int u)
{
ll res=0;
for(;u>=1;u-=lowbit(u))
res+=sum[u];
return res;
}
struct p{
int h,c,p,cidx;
} a[maxn];
bool cmp(p a,p b)
{
return a.h>b.h;
}
int main()
{
int n;
while(~scanf("%d",&n)){
for(int i=1;i<=n;i++){
scanf("%d%d%d",&a[i].h,&a[i].c,&a[i].p);
hh[i]=a[i].h;
cc[i]=a[i].c;
}
sort(hh+1,hh+n+1);
sort(cc+1,cc+n+1);
hm=unique(hh+1,hh+n+1)-hh-1; //高度离散化后数组大小
cm=unique(cc+1,cc+n+1)-cc-1; //花费
for(int i=1;i<=cm;i++)
num[i]=sum[i]=0;
for(int i=1;i<=hm;i++)
hnum[i]=0;
for(int i=1;i<=n;i++){
a[i].h=lower_bound(hh+1,hh+hm+1,a[i].h)-hh;
a[i].cidx=lower_bound(cc+1,cc+cm+1,a[i].c)-cc;
hnum[a[i].h]+=a[i].p; //相同高度的树的数目
add(a[i].cidx,a[i].p,a[i].c); //树状数组维护按价格为下标的数量前缀和和花费前缀和
}
sort(a+1,a+n+1,cmp); //按高度降序排序
int now=1;
ll mcost=0,ans=-1,cnt=0;
for(int i=1;i<=n;i++)
cnt+=a[i].p; //总数目
for(int i=hm;i>=1;i--){ //从最高处理到最低
ll tmp=0,num=hnum[i];
while(now<=n&&a[now].h==i){
tmp+=1ll*a[now].p*a[now].c; //需要被砍掉的
add(a[now].cidx,-a[now].p,a[now].c);
cnt-=a[now].p;
now++;
}
ll need=cnt-num+1; //要求大于一半
if(need<=0){ //不需要再砍
if(ans==-1||mcost<ans) ans=mcost;
mcost+=tmp;
continue;
}
int l=1,r=cm;
while(l<=r){ //二分需要砍掉的价格区间
int mid=(l+r)>>1;
ll num=suma(mid);
if(num>=need) r=mid-1;
else l=mid+1;
}
ll tback=suma(l)-need;
ll tans=sumb(l)+mcost-tback*cc[l];
if(ans==-1||tans<ans) ans=tans;
mcost+=tmp;
}
printf("%lld\n",ans);
}
return 0;
}
【代码2】
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
struct node{
int h,c,m;
}a[100002];
LL m[202];
bool cmp(node x,node y)
{
return x.h>y.h;
}
int main()
{
int n;
while(~scanf("%d",&n)){
memset(m,0,sizeof(m));
LL sum=0;
for(int i=1;i<=n;i++){
scanf("%d%d%d",&a[i].h,&a[i].c,&a[i].m);
m[a[i].c]+=a[i].m;
sum+=a[i].m;
}
sort(a+1,a+1+n,cmp);
LL ans=1e18,tmp=0;
int pos=1;
while(pos<=n){
LL cnt=a[pos].m;
m[a[pos].c]-=a[pos].m;
LL res=1LL*a[pos].c*a[pos].m;
sum-=a[pos].m;
while(pos+1<=n&&a[pos].h==a[pos+1].h){
pos++;
cnt+=a[pos].m;
m[a[pos].c]-=a[pos].m;
res+=1LL*a[pos].c*a[pos].m;
sum-=a[pos].m;
}
LL subx=0,mo=0;cnt--;
for(int i=1;i<=200;i++){
if(sum-subx<=cnt) break;
if(sum-subx-m[i]<=cnt){
mo+=1LL*i*(sum-subx-cnt);
break;
}else{
mo+=1LL*i*m[i];
subx+=m[i];
}
}
ans=min(ans,mo+tmp);
tmp+=res;
pos++;
}
printf("%lld\n",ans);
}
return 0;
}