洛谷 P2034 选择数字
Description
给定一行n个非负整数a[1]..a[n]。现在你可以选择其中若干个数,但不能有超过k个连续的数字被选择。你的任务是使得选出的数字的和最大。
Input
第一行两个整数n,k
以下n行,每行一个整数表示a[i]。
Output
输出一个值表示答案。
Sample Input
5 2
1
2
3
4
5
Sample Output
12
Data Size
对于20%的数据,n <= 10
对于另外20%的数据, k = 1
对于60%的数据,n <= 1000
对于100%的数据,1 <= n <= 100000,1 <= k <= n,0 <= 数字大小 <= 1,000,000,000
时间限制500ms
题解:
线性dp。
正解是逆向思维。把“选数”变为“删数”。然后用单调队列优化。O(n)可过。
我很菜,没想到这种方法。所以一开始写了一个二维的dp。
正向思维。dp(i, 0)表示前i个数中,不选第i这个数的最大和;dp(i, 1)表示前i个数中,选第i这个数的最大和。转移方程显然易见(详见代码)。但是复杂度明显O(nk),过不了。卡了常数后拿到了90pts:
#include
#include
#define N 1000005
#define LL long long
#define re register
using namespace std;
LL n, k, ans;
LL a[N], sum[N];
LL dp[N][2];
LL read()
{
LL x = 0; char c = getchar();
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
return x;
}
int main()
{
cin >> n >> k;
for(re LL i = 1; i <= n; i++)
a[i] = read(), sum[i] = sum[i - 1] + a[i];
for(re LL i = 1; i <= n; i++)
{
for(re LL j = i - 1; j >= i - k && j >= 0; j--)
dp[i][0] = max(dp[i][0], max(dp[j][0], dp[j][1]));
for(re LL j = i - 1; j >= i - k && j >= 0; j--)
dp[i][1] = max(dp[i][1], dp[j][0] + sum[i] - sum[j]);
}
for(re LL i = 1; i <= n; i++) ans = max(ans, max(dp[i][0], dp[i][1]));
cout << ans;
return 0;
}
TLE的问题出在哪里呢?
出在代码中的这一段:
for(int j = i - 1; j >= i - k && j >= 0; j--)
dp(i, 0) = max(dp(i, 0), max(dp(j, 0), dp(j, 1)));
for(int j = i - 1; j >= i - k && j >= 0; j--)
dp(i, 1) = max(dp(i, 1), dp(j, 0) + sum[i] - sum[j]);
可以发现,找max的过程可以用线段树维护。
具体就是开两个线段树。一个线段树维护每个位置dp(j, 0)和dp(j, 1)的最值。另一个线段树维护dp(j, 0) - sum[j]的最值。(因为sum[i]是定值,故可不用维护)
那么上面的代码就可以改写成这样:
for(int i = 1; i <= n; i++)
{
int minn = max(i - k, 0), v1, v2;
v1 = ask1(1, minn, i - 1);
dp(i, 1) = v1 + sum[i];
v2 = ask2(1, minn, i - 1);
dp(i, 0) = v2;
update1(1, i, i, dp(i, 0) - sum[i]);
update2(1, i, i, max(dp(i, 1), dp(i, 0)));
}
复杂度O(nlogn)。可过。
#include
#include
#define N 1000005
#define LL long long
using namespace std;
struct Tree {LL l, r, val, tag;} tree1[N * 4], tree2[N * 4];
LL n, k;
LL a[N], sum[N];
LL dp[N][2];
LL read()
{
LL x = 0; char c = getchar();
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
return x;
}
void build(LL root, LL l, LL r)
{
tree1[root].l = l, tree1[root].r = r;
tree2[root].l = l, tree2[root].r = r;
if(l == r) return;
LL mid = (l + r) >> 1;
build(root << 1, l, mid);
build(root << 1 | 1, mid + 1, r);
}
void down1(LL root)
{
LL son1 = root << 1, son2 = root << 1 | 1;
tree1[son1].tag += tree1[root].tag, tree1[son2].tag += tree1[root].tag;
tree1[son1].val += tree1[root].tag, tree1[son2].tag += tree1[root].tag;
tree1[root].tag = 0;
}
void down2(LL root)
{
LL son1 = root << 1, son2 = root << 1 | 1;
tree2[son1].tag += tree2[root].tag, tree2[son2].tag += tree2[root].tag;
tree2[son1].val += tree2[root].tag, tree2[son2].tag += tree2[root].tag;
tree2[root].tag = 0;
}
LL ask1(LL root, LL l, LL r)
{
if(tree1[root].l >= l && tree1[root].r <= r) return tree1[root].val;
if(tree1[root].tag) down1(root);
LL mid = (tree1[root].l + tree1[root].r) >> 1;
if(l <= mid && r > mid) return max(ask1(root << 1, l, r), ask1(root << 1 | 1, l, r));
else if(l <= mid) return ask1(root << 1, l, r);
else if(r > mid) return ask1(root << 1 | 1, l, r);
}
LL ask2(LL root, LL l, LL r)
{
if(tree2[root].l >= l && tree2[root].r <= r) return tree2[root].val;
if(tree2[root].tag) down2(root);
LL mid = (tree2[root].l + tree2[root].r) >> 1;
if(l <= mid && r > mid) return max(ask2(root << 1, l, r), ask2(root << 1 | 1, l, r));
else if(l <= mid) return ask2(root << 1, l, r);
else if(r > mid) return ask2(root << 1 | 1, l, r);
}
void update1(LL root, LL l, LL r, LL add)
{
if(tree1[root].l >= l && tree1[root].r <= r)
{
tree1[root].tag += add, tree1[root].val += add;
return;
}
if(tree1[root].tag) down1(root);
LL mid = (tree1[root].l + tree1[root].r) >> 1;
if(l <= mid) update1(root << 1, l, r, add);
if(r > mid) update1(root << 1 | 1, l, r, add);
tree1[root].val = max(tree1[root << 1].val, tree1[root << 1 | 1].val);
}
void update2(LL root, LL l, LL r, LL add)
{
if(tree2[root].l >= l && tree2[root].r <= r)
{
tree2[root].tag += add, tree2[root].val += add;
return;
}
if(tree2[root].tag) down2(root);
LL mid = (tree2[root].l + tree2[root].r) >> 1;
if(l <= mid) update2(root << 1, l, r, add);
if(r > mid) update2(root << 1 | 1, l, r, add);
tree2[root].val = max(tree2[root << 1].val, tree2[root << 1 | 1].val);
}
int main()
{
cin >> n >> k;
for(LL i = 1; i <= n; i++)
a[i] = read(), sum[i] = sum[i - 1] + a[i];
build(1, 0, n);
for(LL i = 1; i <= n; i++)
{
LL minn = max(i - k, (LL)0), v1, v2;
v1 = ask1(1, minn, i - 1);
dp[i][1] = v1 + sum[i];
v2 = ask2(1, minn, i - 1);
dp[i][0] = v2;
update1(1, i, i, dp[i][0] - sum[i]);
update2(1, i, i, max(dp[i][1], dp[i][0]));
}
cout << ask2(1, 1, n);
return 0;
}