题目
思路
-
问题转换:满足规则的区间元素取2倍,其他区间取1倍,故原问题转换为 取 与 不取
-
关键:如何遍历? 遍历a序列,构建状态空间树,采用bfs的策略,依据最近遍历的连续取或者连续不取的数量对中间状态进行分类(即分为2k+1类),对于每个分类而言,贪心成立,因此可持续优化,删减中间状态数量。遍历过程如下:
-
前k+1个元素不能取
-
从第k+2个元素开始,有取与不取两种选择
-
连续选择“取”次数小于k次的有取与不取两种选择,但连续取 k 次后,下一次不能取,
-
连续k+1次不选后,下一次有取与不取两种选择
-
总结:一般而言,访问一个元素后,有 2k+1 个状态作为访问下一个元素的初始状态
-
图中,下面的k+1个连续不取的状态用数组维护
-
图中,上面的k个连续取的状态维护想到了2种策略:单调队列或者线段树
-
单调队列:k个连续取的状态中,贡献值小于等于 “最后连续1个取” 的状态舍弃,同时转换操作只需关注单调队列的队首(最大值)即可
-
线段树:在遍历a序列的过程种,k个状态在不断变化,把这k个状态看作是一个序列中宽度为k的一个区间,每次向后移动1个位置,有更新和查询2种操作。因此是线段树模型。
代码
单调队列
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define MAXN 100010
LL a[MAXN], b[MAXN], c[MAXN], ans, delt;
deque<int> q;
int main(){
int n, k;
scanf("%d %d", &n, &k);
ans = 0;
for(int i = 1; i <= n; i++){
scanf("%lld", &a[i]);
ans += a[i];
}
delt = 0, c[2] = 0, q.push_back(2);
for(int i = k+2; i <= n; i++){
int f = q.front();
if(f + k == i) q.pop_front();
b[i] = c[f] + delt;
b[i-k] = max(b[i-k], b[i-k-1]);
c[i] = b[i-k-1] - delt;
delt += a[i];
while(q.size() && c[q.back()] <= c[i]) q.pop_back();
q.push_back(i);
}
for(int i = n-k+1; i <= n; i ++) b[n-k] = max(b[n-k], b[i]);
ans += max(delt + c[q.front()], b[n-k]);
printf("%lld\n", ans);
return 0;
}
线段树
#include<bits/stdc++.h>
using namespace std;
#define MAXN 100010
struct sTN{
int l, r;
long long m, v;
}tr[MAXN<<2];
long long bk[MAXN], ans;
int a[MAXN], n, k, pb;
void init();
void build_tree(int, int, int);
void add(int, int, long long, int);
long long query(int, int, int);
void push(int);
void up(int);
int main(){
long long tmp, tmp2;
scanf("%d %d", &n, &k);
for(int i = 1; i <= n; i++) scanf("%d", a+i);
init();
for(int i = k+2; i <= n; i++){
tmp = query(i-k-1, i-1, 1);
ans = max(ans, tmp);
add(i, i, bk[pb], 1);
add(i-k+1, i, a[i], 1);
tmp2 = bk[pb];
bk[pb] = tmp; // pb指向k+1个空
pb = (pb + 1) % (k+1); // pb指向1个空
bk[pb] = max(bk[pb], tmp2);
}
ans = max(ans, query(n-k+1, n, 1));
for(int i = 1; i <= n; i++) ans += a[i];
printf("%lld\n", ans);
return 0;
}
void init(){
pb = 0, ans = 0;
memset(bk, 0, sizeof bk);
memset(tr, 0, sizeof tr);
build_tree(1, n, 1);
memset(bk, 0, sizeof bk);
}
void build_tree(int l, int r, int root){
sTN & rt = tr[root];
rt.l = l, rt.r = r, rt.m = 0, rt.v = 0;
if(l == r) return;
int mid = (l+r)>>1;
build_tree(l, mid, root<<1);
build_tree(mid+1, r, (root<<1)+1);
}
void add(int l, int r, long long data, int root){
sTN & rt = tr[root];
if(l == rt.l && r == rt.r) {
rt.m += data, rt.v += data;
if(root > 1) up(root>>1);
return;
}
if(rt.v > 0) push(root);
int mid = (rt.l+rt.r) >> 1;
if(r <= mid)
add(l, r, data, root << 1);
else if(mid+1 <= l)
add(l, r, data, (root<<1)+1);
else{
add(l, mid, data, root << 1);
add(mid+1, r, data, (root<<1)+1);
}
if(root > 1) up(root>>1);
}
long long query(int l, int r, int root){ // 查询最大值
sTN & rt = tr[root];
if(l == rt.l && r == rt.r) return rt.m;
if(rt.v > 0) push(root);
int mid = (rt.l+rt.r) >> 1;
if(r <= mid)
return query(l, r, root << 1);
else if(mid+1 <= l)
return query(l, r, (root<<1)+1);
else
return max(query(l, mid, root << 1), query(mid+1, r, (root<<1)+1));
}
void push(int root){
sTN & rt = tr[root];
if(rt.l == rt.r) return;
sTN &lrt = tr[root<<1], &rrt = tr[(root<<1)+1];
lrt.v += rt.v, rrt.v += rt.v;
lrt.m += rt.v, rrt.m += rt.v;
rt.v = 0;
}
void up(int root){
if(root <= 1) return;
sTN & rt = tr[root], &lrt = tr[root<<1], &rrt = tr[(root<<1)+1];
int mid = (rt.l+rt.r) >> 1;
rt.m = max(lrt.m, rrt.m);
}