题目大意:给定一个序列,求出k个这个序列的位置不完全相同的子序列,使得每一个子序列的长度均在[l,r]内,并且使得这些子序列的权值和最大。
思路:每一个子序列的权值和可以转化为两个前缀和之差。我们考虑以每一个位置为结尾的子序列,它的权值和可以看作是以该位置为结尾的前缀和减去它前面的某个前缀和。
那么想要这个子序列的权值和尽量大,那么就要前面的那个前缀和尽可能小。如果数目不够,就第2小。再不够,就第3小。
于是我们维护一个全局堆,分别表示以每个位置为结尾的最大子序列权值和,每次取出堆顶时再放进堆中一个结尾位置相同的第k+1大的权值和。
这样k次就能出解。
这要求我们能够快速求出区间第k小,利用可持久化线段树即可。
时间复杂度O(klogn).
Code:
#include <map>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define N 500010
int w[N];
int Uni[N], sav[N], id;
map<int, int> M;
#define l(x) S[x].l
#define r(x) S[x].r
#define size(x) S[x].size
struct Node {
int l, r, size;
}S[10000000];
int ind, root[N];
int Newadd(int Last, int tl, int tr, int ins, int add) {
int q = ++ind;
S[q] = S[Last];
size(q) += add;
if (tl == tr)
return q;
int mid = (tl + tr) >> 1;
if (ins <= mid)
l(q) = Newadd(l(Last), tl, mid, ins, add);
else
r(q) = Newadd(r(Last), mid + 1, tr, ins, add);
return q;
}
inline int kth(int Left, int Right, int tl, int tr, int k) {
int mid, ls;
while(1) {
if (tl == tr)
return tl;
mid = (tl + tr) >> 1;
if ((ls = size(l(Right)) - size(l(Left))) >= k)
Left = l(Left), Right = l(Right), tr = mid;
else
Left = r(Left), Right = r(Right), tl = mid + 1, k -= ls;
}
return -1;
}
struct State {
int end, kth, val;
State(int _end = 0, int _kth = 0, int _val = 0):end(_end),kth(_kth),val(_val){}
bool operator < (const State &B) const {
return val > B.val;
}
};
struct Heap {
State a[N];
int top;
Heap():top(0){}
inline void up(int x) {
for(; x != 1; x >>= 1)
if (a[x] < a[x>>1])
swap(a[x], a[x>>1]);
else
break;
}
inline void down(int x) {
int son;
for(; (x<<1) <= top;) {
son = (((x<<1)==top)||(a[x<<1]<a[(x<<1)|1]))?(x<<1):((x<<1)|1);
if (a[son]<a[x])
swap(a[son],a[x]),x=son;
else
break;
}
}
void push(const State &x) {
a[++top] = x;
up(top);
}
State Max() {
return a[1];
}
void pop() {
a[1] = a[top--];
down(1);
}
}H;
int insl[N], insr[N];
int main() {
#ifndef ONLINE_JUDGE
freopen("tt.in", "r", stdin);
#endif
int n, k, L, R;
scanf("%d%d%d%d", &n, &k, &L, &R);
register int i;
for(i = 1; i <= n; ++i) {
scanf("%d", &w[i]);
w[i] += w[i - 1];
}
for(i = 0; i <= n; ++i)
Uni[i + 1] = w[i];
sort(Uni + 1, Uni + n + 2);
Uni[0] = -1 << 30;
for(i = 1; i <= n + 1; ++i)
if (Uni[i] != Uni[i - 1])
sav[++id] = Uni[i], M[Uni[i]] = id;
for(i = 0; i <= n; ++i)
w[i] = M[w[i]];
root[0] = Newadd(0, 1, id, w[0], 1);
for(i = 1; i <= n; ++i)
root[i] = Newadd(root[i - 1], 1, id, w[i], 1);
for(i = L; i <= n; ++i) {
insl[i] = max(i - R, 0);
insr[i] = i - L;
H.push(State(i, 1, sav[w[i]] - sav[kth((insl[i] == 0) ? 0 : root[insl[i] - 1], root[insr[i]], 1, id, 1)]));
}
long long res = 0;
for(i = 1; i <= k; ++i) {
State tmp = H.Max();
H.pop();
res += tmp.val;
if (tmp.kth != insr[tmp.end] - insl[tmp.end] + 1)
H.push(State(tmp.end, tmp.kth + 1, sav[w[tmp.end]] - sav[kth(insl[tmp.end] == 0 ? 0 : root[insl[tmp.end] - 1], root[insr[tmp.end]], 1, id, tmp.kth + 1)]));
}
printf("%lld", res);
return 0;
}