题意:
给出一段长度为n的序列w[1]、w[2]、… 、w[n](1 ≤ w ≤ m),对于所有 i ∈[1, n ],最少 要令多少 w[j] (其中 j ∈[1, i-1])变为0,才能使得 ∑ k = 1 i w [ k ] ⩽ m \sum_{k=1}^{i}w[k]\leqslant m k=1∑iw[k]⩽m
分析:
对w[i],先求一下前缀和sum,首先如果满足sum≤m,则ans[i]=0;
当sum>m,我们就要删除1 ~ i-1的元素(变为0),为了删除的个数最少,很明显要优先删除大的元素,但暴力排序肯定不行,这里就要构建一棵权值线段树。
把w离散化后构建权值线段树,其叶结点的从左到右依次编号就代表权值,但里面存储的不是该权值的个数,而是总和,例如权值为a,有b个,那么就是a*b,pushup时也是相加求和。
查询的时候就是查找m-sum最少是多少个权值相加,利用二分的思想在线段树上查找,大的值优先所以优先查找右子树。(锁定分界点位置后依然需要用记录个数的权值线段树得到ans[i])
具体看代码注释。
以下代码:
#include<bits/stdc++.h>
#define LL long long
#define PLL pair<LL,LL>
using namespace std;
const int INF=0x3f3f3f3f;
const int maxn=2e5+50;
int n,N;
LL m,w[maxn],b[maxn],s[maxn<<2],c[maxn<<2];
void init()
{
sort(b+1,b+n+1);
N=unique(b+1,b+n+1)-(b+1);
for(int i=1;i<=n;i++)
w[i]=lower_bound(b+1,b+N+1,w[i])-b;
}
void updata(int rt,int l,int r,int p)
{
if(l==r)
{
s[rt]+=b[l]; //记录权值p的总和
c[rt]++; //记录权值p的个数
return;
}
int mid=(l+r)>>1;
if(p<=mid)
updata(rt<<1,l,mid,p);
else
updata(rt<<1|1,mid+1,r,p);
s[rt]=s[rt<<1]+s[rt<<1|1];
c[rt]=c[rt<<1]+c[rt<<1|1];
}
PLL query1(int rt,int l,int r,LL k) //二分找到需删除权值的分界点
{
if(l==r)
{
if(k%b[l]==0)
return PLL(l,k/b[l]); //返回分界点权值
else //以及需要删除的该分界点权值的具体个数
return PLL(l,k/b[l]+1);
}
int mid=(l+r)>>1;
if(s[rt<<1|1]>=k) //若右子树总和大于k,往右子树找
return query1(rt<<1|1,mid+1,r,k);
else
return query1(rt<<1,l,mid,k-s[rt<<1|1]); //否则往左子树找,注意k要减去右子树的权值
//(说明右子树的元素都要删除)
}
LL query2(int rt,int l,int r,int ql,int qr) //查询个数
{
if(ql<=l&&r<=qr)
return c[rt];
if(r<ql||l>qr)
return 0;
int mid=(l+r)>>1;
return query2(rt<<1,l,mid,ql,qr)+query2(rt<<1|1,mid+1,r,ql,qr);
}
int main()
{
int Q;
scanf("%d",&Q);
while(Q--)
{
scanf("%d %lld",&n,&m);
for(int i=1;i<=n;i++)
{
scanf("%lld",&w[i]);
b[i]=w[i];
}
init(); //离散化
memset(s,0,sizeof(s));
memset(c,0,sizeof(c));
LL sum=0,ans[maxn];
for(int i=1;i<=n;i++)
{
sum+=b[w[i]];
if(sum>m)
{
PLL temp=query1(1,1,N,sum-m);
ans[i]=temp.second; //需删除的分界点权值的具体个数
if(temp.first+1<=m)
ans[i]+=query2(1,1,N,temp.first+1,m); //大于分界点权值的元素个数
}
else
ans[i]=0;
updata(1,1,N,w[i]);
}
for(int i=1;i<=n;i++)
printf("%lld ",ans[i]);
printf("\n");
}
return 0;
}