Find the answer 权值线段树
题意:给n个数字,m值,输出n个值,每次问对于前缀和1~i(1<=i<=n) 中,最少去掉(1,i-1)中多少个数字才能使前缀和小于m;
思路:很多做法,可以使用权值线段树,用权值线段树记录数组的前缀和和前缀和所对应得数字个数,因为权值线段树中叶子节点记录的数字是从小到大的,所以我们可以求出权值线段树前缀和小于一个值所能最多添加多少个数字(优先取左子树的数字),
也就是每次求最多有多少个数字的前缀和小于m-a[i]求出ans,然后i-1-ans就是所需减掉的最小个数
#include <bits/stdc++.h> using namespace std; #define ll long long const int maxn=2e5+5; ll sum[maxn*4]; int num[maxn*4]; // sum记录前缀和 num记录个数 int a[maxn]; int b[maxn]; void build(int root,int l,int r) { sum[root]=0; num[root]=0; if(l==r) return; int mid=(l+r)/2; build(root<<1,l,mid); build(root<<1|1,mid+1,r); } void update(int root,int l,int r,int index) { if(l==r) { sum[root]+=b[l]; num[root]++; return; } int mid=(l+r)/2; if(index<=mid) update(root<<1,l,mid,index); else update(root<<1|1,mid+1,r,index); sum[root]=sum[root<<1]+sum[root<<1|1]; num[root]=num[root<<1]+num[root<<1|1]; } int query(int root,int l,int r,int k) // 查询最多有多少个数字和<=k { if(l==r) { // return k/b[l]; if(k == 0) return 0; else if(k >= b[l]) { return k / b[l]; } return 0; } int mid=(l+r)/2; ll Lsum=sum[root<<1]; if(Lsum>=k) return query(root<<1,l,mid,k); return num[root<<1]+query(root<<1|1,mid+1,r,k-Lsum); } int main() { int T; scanf("%d",&T); while(T--) { ll n,m; scanf("%lld%lld",&n,&m); for(int i = 1; i <= n; ++i) { scanf("%d",&a[i]); b[i]=a[i]; } sort(b+1,b+1+n); int cnt=unique(b+1,b+1+n)-b-1; build(1,1,n); ll ret=0; for(int i=1; i<=n; i++) { ret+=a[i]; if(ret<=m) printf("0 "); else { ll tmp=m-a[i]; int ans=query(1,1,n,tmp); printf("%d ",i-1-ans); } int pos=lower_bound(b+1,b+1+cnt,a[i])-b; update(1,1,n,pos); } printf("\n"); } return 0; }
做法2 树状数组+二分
思路:对n个数字离散化,注意这里不能去重,之后用两个树状数组,一个维护目前的前缀和,一个维护前缀和对应得数字个数,一次查询复杂度为log n,
#include<bits/stdc++.h> using namespace std; const int maxn=2e5+10; #define ll long long struct note { int id,num; } b[maxn]; ll sum[maxn],num[maxn]; int n,m; int a[maxn]; int pos[maxn]; ll asksum(int x) { ll ans=0; for(; x; x-=x&-x) ans+=sum[x]; return ans; } void addsum(int x,int y) { for(; x<=n; x+=x&-x) sum[x]+=y; } ll asknum(int x) { ll ans=0; for(; x; x-=x&-x) ans+=num[x]; return ans; } void addnum(int x,int y) { for(; x<=n; x+=x&-x) num[x]+=y; } int cmp(note a,note b) { return a.num<b.num; } int main() { int T; scanf("%d",&T); while(T--) { scanf("%d%d",&n,&m); for(int i=1; i<=n; i++) sum[i]=num[i]=0; for(int i=1; i<=n; i++) { scanf("%d",&a[i]); b[i].id=i; b[i].num=a[i]; } sort(b+1,b+1+n,cmp); for(int i=1; i<=n; i++) pos[b[i].id]=i; ll ret=0; for(int i=1; i<=n; i++) { ret+=a[i]; if(ret<=m) printf("0 "); else { ll temp=m-a[i]; int l,r,mid; l=0,r=n+1; while(l<r) { mid=(l+r+1)/2; if(asksum(mid)<=temp) l=mid; else r=mid-1; } printf("%d ",i-1-asknum(l)); } addnum(pos[i],1); addsum(pos[i],a[i]); } printf("\n"); } }