题目概述
给出一个序列,选出 k k 个长度在 的子段(不可选重),求 k k 个子段的和的最大值。
解题报告
2017.8.8Update:好像因为这种贪心套路特别经典就被冠名超级钢琴了233333。
如果我们把所有长度在 的子段都处理出来并从大到小排序,那么根据贪心,肯定选前 k k 个最优秀。
但我们不可能把所有满足要求的子段都处理出来:太多了。需要注意到的是, 并不是很大,所以我们要想办法每次都选最大的满足要求的子段, 选 k k 次累加起来就是答案。
定义 ,即以 i i 为右端点,左端点范围在 之间的最大子段, sum s u m 是前缀和。我们考虑一个位置 i i , 是以 i i 为右端点的初始最优解,那么从所有初始最优解中刷出最大的就是第一大的满足子段。
假设第一大的三元组是 ,最优解位置在 t t ,那么由于 已经被选了,所以 [l,r] [ l , r ] 被拆成了 [l,t−1] [ l , t − 1 ] 和 [t+1,r] [ t + 1 , r ] ,把 (i,l,t−1) ( i , l , t − 1 ) 和 (i,t+1,r) ( i , t + 1 , r ) 加入待选三元组。不停地从待选三元组中选出 MAX M A X 最大的三元组,并加入新产生的三元组,选 k k 次即可。
已经有了想法,接下来我们只需要解决两个问题即可:
- 如何快速求出 :由于 i i 固定,所以用ST算法预处理区间最小值即可。
- 如何快速选出最大的 :用堆即可。
示例程序
#include<cstdio> #include<cmath> #include<queue> using namespace std; typedef long long LL; const int maxn=500000,Log=19; int n,K,L,R,sum[maxn+5],RMQ[maxn+5][Log+5]; LL ans; inline bool Eoln(char ch) {return ch==10||ch==13||ch==EOF;} inline char readc() { static char buf[100000],*l=buf,*r=buf; if (l==r) r=(l=buf)+fread(buf,1,100000,stdin); if (l==r) return EOF; else return *l++; } inline int readi(int &x) { int tot=0,f=1;char ch=readc(),lst='+'; while ('9'<ch||ch<'0') {if (ch==EOF) return EOF;lst=ch;ch=readc();} if (lst=='-') f=-f; while ('0'<=ch&&ch<='9') tot=tot*10+ch-48,ch=readc(); return x=tot*f,Eoln(ch); } int Miner(int i,int j) {if (sum[i-1]<sum[j-1]) return i; else return j;} void make_RMQ() { for (int j=1,k=log2(n);j<=k;j++) for (int i=1;i<=n-(1<<j)+1;i++) RMQ[i][j]=Miner(RMQ[i][j-1],RMQ[i+(1<<j-1)][j-1]); } int Ask(int L,int R) {int j=log2(R-L+1);return Miner(RMQ[L][j],RMQ[R-(1<<j)+1][j]);} struct data { int i,L,R,t; data(int a,int b,int c,int d) {i=a;L=b;R=c;t=d;} bool operator < (const data &c) const {return sum[i]-sum[t-1]<sum[c.i]-sum[c.t-1];} }; priority_queue<data> Heap; int main() { freopen("program.in","r",stdin); freopen("program.out","w",stdout); readi(n);readi(K);readi(L);readi(R); for (int i=1,x;i<=n;i++) readi(x),sum[i]=sum[i-1]+x,RMQ[i][0]=i;make_RMQ(); while (!Heap.empty()) Heap.pop(); for (int i=L;i<=n;i++) { int l=max(i-R+1,1),r=i-L+1; Heap.push(data(i,l,r,Ask(l,r))); } while (K--) { data now=Heap.top();Heap.pop(); ans+=sum[now.i]-sum[now.t-1]; if (now.t>now.L) Heap.push(data(now.i,now.L,now.t-1,Ask(now.L,now.t-1))); if (now.t<now.R) Heap.push(data(now.i,now.t+1,now.R,Ask(now.t+1,now.R))); } printf("%lld\n",ans); return 0; }