文章目录
沧海的孤塔-chimera
题意:
n n n 个时间点,每个时间点有一个贡献 a i a_i ai ,从 n n n 个中选择 m m m 个出来,要求任意连续的 k k k 个时刻至少要选择一个,否则之后将无法再选择
求能取得的最大贡献。
思路:
明显的 d p dp dp ,但是有点点难写。
d p [ i ] [ j ] dp[i][j] dp[i][j] 表示前 i i i 个时间点中已经选了 j j j 个(并且包含第 i i i 个)。
转移方程: d p [ i ] [ j ] = m a x { d p [ x ] [ j − 1 ] + a [ i ] ∣ i − k ≤ x ≤ i − 1 } dp[i][j] = max\{dp[x][j - 1] + a[i]\ \ |\ \ i - k \leq x \leq i-1 \} dp[i][j]=max{dp[x][j−1]+a[i] ∣ i−k≤x≤i−1}
求最大值用线段树来维护一下, j j j 维度每完成一次,就以上一维的数重新建线段树得到最大值。
实现起来需要亿点点细节。
第一层循环是 j j j ,第二层循环是 i i i ,但 i i i 是从 j j j 开始的,因为前 i i i 个最多只能选择 i i i 个那么前 j − 1 j - 1 j−1 个就是 0 0 0 。然后 i i i 最多能到 m i n ( n , k ∗ j ) min(n, k * j) min(n,k∗j) ,因为每 k k k 个选择一个时间点,所以选择 j j j 个最多能到达 k ∗ j k*j k∗j 。但是这里要注意可能会超过 n n n 。
我在这里RE了两发最后查找答案时,需要从最后 k k k 个中取最大值,因为 d p dp dp 转移方程是从前往后一保证每 k k k 个至少有一个。查找时需要保证从后面开始也是满足条件的。因为 m ≥ n / k m \geq n/k m≥n/k ,即 n ≤ m k n \leq mk n≤mk ,所以肯定能找到满足条件的 m m m 个时间点。
好难写啊
#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<map>
#include<cstring>
#include<algorithm>
#define fi first
#define se second
//#define int long long
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int N = 4010;
const int M = 4010;
ll n, m, k;
ll tree[N << 2];
ll a[N];
ll tp[N];
#define ls (rt << 1)
#define rs ((rt << 1) | 1)
void build(ll l, ll r, ll rt) {
if (l == r) {
tree[rt] = tp[l];
return;
}
ll mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
tree[rt] = max(tree[ls], tree[rs]);
}
ll query(ll st, ll ed, ll l, ll r, ll rt) {
if (st > ed) return 0LL;
if (st <= l && r <= ed) return tree[rt];
ll mid = (l + r) >> 1;
ll ans = 0;
if (st <= mid) ans = max(ans, query(st, ed, l, mid, ls));
if (mid < ed) ans = max(ans, query(st, ed, mid + 1, r, rs));
return ans;
}
//void modify(int l, int r, int rt, int k, ll v) {
// if (l == r) {
// tree[rt] = v;
// return;
// }
// int mid = (l + r) >> 1;
// if (k <= mid) modify(l, mid, ls, k, v);
// else modify(mid + 1, r, rs, k ,v);
// tree[rt] = max(tree[ls], tree[rs]);
//}
ll dp[N][M];
int main() {
scanf("%lld%lld%lld", &n, &m, &k);
for (int i = 1; i <= n; i++)
scanf("%lld", &a[i]);
build(1, n, 1);
for (int j = 1; j <= m; j++) {
for (int i = 1; i <= n; i++) tp[i] = 0;
for (int i = j; i <= min(k * j, n); i++) {
dp[i][j] = query(max(i - k, 1LL), max(i - 1, 1), 1, n, 1) + a[i];
tp[i] = dp[i][j];
}
build(1, n, 1);
}
// for (int j = 1; j <= m; j++) {
// for (int i = 1; i <= n; i++) {
// printf("%lld ", dp[i][j]);
// }
// printf("\n");
// }
ll ans = 0;
for (int i = n - k + 1; i <= n; i++) {
ans = max(ans, dp[i][m]);
}
printf("%lld\n", ans);
return 0;
}