题意描述
给你一个长度为 n n n 的序列,执行一次以下操作:
选 k k k 个数 + x +x +x,其余未选择的数 − x -x −x。
求操作后的最大子段和。
简要分析
观察到要求最大子段和,但发现需要对数组进行操作,于是考虑如何改编最大字段和。
这里我们改变使用分治(线段树)的方法求解。
关于如何使用分治(线段树)的方法求解最大子段和这里不再赘述。
观察到在建树时递归到最底端,也就是 l = = r l == r l==r 是,此时维护的最大字段和、以左端点为首元素的最大子段和,以右端点为尾元素的最大子段和均等于 a l a_l al 即 a r a_r ar。
我们可以在这里分类讨论,如果 k = 0 k = 0 k=0 那么只需将 a l − x a_l - x al−x 即 a r − x a_r - x ar−x 即可。
若 k ≠ 0 k \not= 0 k=0 要么将 a l + x a_l + x al+x 即 a r + x a_r + x ar+x 要么将 a l − x a_l - x al−x 即 a r − x a_r -x ar−x。
我们只需讲这两种情况均与另一节点合并至父节点中即可。
合并时仅需枚举配对情况取最大值即可。
事实上,这里看似有 2 x 2^x 2x 种合并方式,然而我们只需其中的 k k k 种。
因为有和多种情况所算出来的答案是相同的。
又因为我们不需要知道怎么选,只需知道最大子段和的值。
所以可以将 2 x 2^x 2x 种合并方式简化为 k k k 种合并方式。
又因为此题仅需一次建树一次查询所以可以省去许多冗余的函数,代码十分简洁。
时间复杂度为 O ( n k ) O(nk) O(nk)。
代码实现
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <string>
#include <cmath>
#include <vector>
using namespace std;
typedef long long ll;
const ll maxn = 2e5 + 7;
const ll INF = 1e18 + 7, MOD = 998244353;
inline ll read() {
char cCc;
ll xXx = 0, wWw = 1;
while (cCc < '0' || cCc > '9')
(cCc == '-') && (wWw = -wWw), cCc = getchar();
while (cCc >= '0' && cCc <= '9')
xXx = (xXx << 1) + (xXx << 3) + (cCc ^ '0'), cCc = getchar();
xXx *= wWw;
return xXx;
}
inline void write(ll xXx) {
if (xXx < 0)
putchar('-'), xXx = -xXx;
if (xXx > 9)
write(xXx / 10);
putchar(xXx % 10 + '0');
}
struct node {
ll sum, lq, rq, qs;
node() {
sum = -INF, lq = -INF, rq = -INF, qs = -INF;
}
node(ll sum, ll lq, ll rq, ll qs) : sum(sum), lq(lq), rq(rq), qs(qs) {}
};
ll n, k, x, a[maxn];
vector<node> solve(ll l, ll r) {
if (l == r) {
ll s1 = max(0ll, a[r] + x);
ll s2 = max(0ll, a[r] - x);
if (k == 0) return {node(a[r] - x, s2, s2, s2)};
return {node(a[r] - x, s2, s2, s2), node(a[r] + x, s1, s1, s1)};
}
ll mid = (l + r) / 2;
auto vl = solve(l, mid), vr = solve(mid + 1, r);
ll len = min(k, r - l + 1) + 1;
vector<node> ans(len);
for (ll i = 0; i < vl.size(); i++) {
for (ll j = 0; j < vr.size() && i + j < len; j++) {
auto &l = vl[i], &r = vr[j], &u = ans[i + j];
u.sum = max(u.sum, l.sum + r.sum);
u.lq = max(u.lq, l.lq);
u.lq = max(u.lq, l.sum + r.lq);
u.rq = max(u.rq, r.rq);
u.rq = max(u.rq, l.rq + r.sum);
u.qs = max(u.qs, l.qs);
u.qs = max(u.qs, r.qs);
u.qs = max(u.qs, l.rq + r.lq);
}
}
return ans;
};
signed main() {
ll T = read();
while (T--) {
n = read(), k = read(), x = read();
for (ll i = 1; i <= n; i++) a[i] = read();
cout << solve(1, n)[k].qs << '\n';
}
}