题目背景:
10.20 NOIP模拟T2 bzoj1084
分析:DP
听说这貌似是一道原题,听说我貌似8天前做过这道题,听说这道题的复杂度是O(nk),听说我8天前写的O(nk),然后今天写出了O(n3k)···(手动再见
首先我们来看这道题,m <= 2,这就是在告诉我们要在上面做文章。
首先考虑m = 1的情况,定义数组f[i][j]表示当前枚举到第i行,已经选择了j个矩阵,然后从前面0 ~ i - 1直接转移过来
f[i][j] = f[i - 1][j]
f[i][j] = max(f[l][j - 1] + sum[i] - sum[l])
(l >= 0 && l < i, sum为矩阵权值的前缀和)
这样转移的复杂度为O(n2k)
Source:
inline void read_in() {
R(n), R(m), R(k);
for (int i = 1; i <= n; ++i)
for (int j = 0; j < m; ++j)
R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}
inline void solve_m_1() {
static int f[MAXN][12];
memset(f, 128, sizeof(f));
for (int i = 0; i <= n; ++i) f[i][0] = 0;
for (int c = 1; c <= k; ++c)
for (int i = 1; i <= n; ++i) {
f[i][c] = f[i - 1][c];
for (int l = 0; l < i; ++l)
f[i][c] = std::max(f[i][c], sum[0][i] - sum[0][l] +
f[l][c - 1]);
}
std::cout << f[n][k];
}
再来讲第二种方法,定义数组f[i][j][0/1]表示,当前枚举到第i行,已经选择了j个矩阵,当前的位置是否被选择了。(如果上一个被选择了,可以考虑直接将这一个接在上一个后面,矩阵数量不增加)
f[i][j][0] = max(f[i - 1][j][1], f[i - 1][j][0])
f[i][j][1] = max(f[i - 1][j - 1][0], f[i - 1][j - 1][1], f[i -1][j][1]) + a[i];
解释:f[i - 1][j -1][0] à 上一个位置没有被选择
f[i - 1][j - 1][1] à 上一个位置被选择了,这一个位置重新开始一个新的矩阵
f[i - 1][j][1] à 上一个位置被选择了,这个位置接在上一个所在的矩阵上面
这样做的复杂度是O(nk)的
Source:
inline void read_in() {
R(n), R(m), R(k);
for (int i = 1; i <= n; ++i)
for (int j = 0; j < m; ++j)
R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}
inline void solve_m_1() {
static int f[MAXN][MAXK][2];
memset(f, 128, sizeof(f)), f[0][0][0] = 0;
for (int i = 1; i <= n; ++i)
for (int j = 0; j <= k; ++j) {
f[i][j][0] = std::max(f[i - 1][j][0], f[i - 1][j][1]);
if (j > 0) f[i][j][1] = std::max(f[i - 1][j - 1][0],
f[i - 1][j - 1][1]) + a[0][i];
f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][1] + a[0][i]);
}
std::cout << std::max(f[n][k][0], f[n][k][1]);
}
我们继续讲下一种,首先因为它只有一行,我们就相当于选择k个互不相交的子段,这是一种费用流的思想,然后我们再做分析发现,它的性质很特殊,我们可以直接利用线段树来维护,每一次选择一段区间相当于把这段区间里面的数据全部取反,这样的复杂度是O(klogn)
Source:
/*
created by scarlyw
*/
// 注:本代码为bzoj3502的代码,原题当中是选择最多k个子区间,
// 而本题则是要求至少k个,只需要把代码中写有注释的地方更改一下即可
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>
inline char read() {
static const int IN_LEN = 1024 * 1024;
static char buf[IN_LEN], *s, *t;
if (s == t) {
t = (s = buf) + fread(buf, 1, IN_LEN, stdin);
if (s == t) return -1;
}
return *s++;
}
///*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = read(), iosig = false; !isdigit(c); c = read()) {
if (c == -1) return ;
if (c == '-') iosig = true;
}
for (x = 0; isdigit(c); c = read())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN], *oh = obuf;
inline void write_char(char c) {
if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
*oh++ = c;
}
template<class T>
inline void W(T x) {
static int buf[30], cnt;
if (x == 0) write_char('0');
else {
if (x < 0) write_char('-'), x = -x;
for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
while (cnt) write_char(buf[cnt--]);
}
}
inline void flush() {
fwrite(obuf, 1, oh - obuf, stdout);
}
/*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = getchar(), iosig = false; !isdigit(c); c = getchar())
if (c == '-') iosig = true;
for (x = 0; isdigit(c); c = getchar())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int MAXN = 1000000 + 10;
struct data {
int lp, rp, p1, p2;
long long lx, mx, rx, sum;
inline void init(int l, long long x) {
lp = rp = p1 = p2 = l, lx = mx = rx = sum = x;
}
friend inline data operator + (const data &a, const data &b) {
data t;
t.sum = a.sum + b.sum;
t.lp = a.lp, t.lx = a.lx;
if (t.lx < a.sum + b.lx) t.lx = a.sum + b.lx, t.lp = b.lp;
t.rp = b.rp, t.rx = b.rx;
if (t.rx < b.sum + a.rx) t.rx = b.sum + a.rx, t.rp = a.rp;
t.mx = a.rx + b.lx, t.p1 = a.rp, t.p2 = b.lp;
if (a.mx > t.mx) t.mx = a.mx, t.p1 = a.p1, t.p2 = a.p2;
if (b.mx > t.mx) t.mx = b.mx, t.p1 = b.p1, t.p2 = b.p2;
return t;
}
} ;
struct node {
data min, max;
bool flag;
inline void init(int l, long long val) {
min.init(l, -val), max.init(l, val);
}
} tree[(MAXN << 1) + 200000];
int n, k;
int a[MAXN];
inline void reverse(int k) {
std::swap(tree[k].min, tree[k].max), tree[k].flag ^= 1;
}
inline void push_down(int k) {
if (tree[k].flag) reverse(k << 1), reverse(k << 1 | 1), tree[k].flag ^= 1;
}
inline void update(int k) {
tree[k].min = tree[k << 1].min + tree[k << 1 | 1].min;
tree[k].max = tree[k << 1].max + tree[k << 1 | 1].max;
}
inline void build_tree(int k, int l, int r) {
if (l == r) return tree[k].init(l, a[l]);
int mid = l + r >> 1;
build_tree(k << 1, l, mid), build_tree(k << 1 | 1, mid + 1, r);
update(k);
}
inline void rever(int k, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) return reverse(k);
push_down(k);
int mid = l + r >> 1;
if (ql <= mid) rever(k << 1, l, mid, ql, qr);
if (qr > mid) rever(k << 1 | 1, mid + 1, r, ql, qr);
update(k);
}
inline void solve() {
R(n), R(k);
for (int i = 1; i <= n; ++i) R(a[i]);
build_tree(1, 1, n);
long long ans = 0;
for (int i = 1; i <= k; ++i) {
data cur = tree[1].max;
if (cur.mx > 0) ans += cur.mx;
else break ;
//取消上面的if判断,直接修改成ans += cur.mx即可
rever(1, 1, n, cur.p1, cur.p2);
}
printf("%lld", ans);
}
int main() {
solve();
return 0;
}
然后我们再来看m = 2的情况。
先来讲讲暴力的O(n3k)的做法,虽然说得是这么高的复杂度,但是因为常数的确挺小,然后,最大数据也就跑了100ms多一些。定义数组f[i][j][k]表示当前第1列选到第i个数,第二列选到第j个数,已经选择了k个矩阵了,那么转移方程很好想。
f[i][j][k] = max(f[i - 1][j][k], f[i][j - 1][k])
f[i][j][k] = max(f[l][j][k - 1] + sum[0][i] - sum[0][l]) (l < i&& l >= 0) sum[0]为第一列的矩阵前缀和
f[i][j][k] = max(f[i][l][k - 1] + sum[1][j] - sum[1][l]) (l < j&& l >= 0) sum[1]为第二列的矩阵前缀和
如果i == j时,可以选择两列一起选择,所以
f[i][j][k] = max(f[l][l][k - 1] + sum[1][i] - sum[1][l] + sum[0][i] -sum[0][l]) (i == j && l < i && l >= 0)
这样DP的复杂度是O(n3k)的
Source:
inline void read_in() {
R(n), R(m), R(k);
for (int i = 1; i <= n; ++i)
for (int j = 0; j < m; ++j)
R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}
inline void solve_m_2() {
static int f[MAXN][MAXN][12];
memset(f, 128, sizeof(f));
for (int i = 0; i <= n; ++i)
for (int j = 0; j <= n; ++j)
f[i][j][0] = 0;
for (int c = 1; c <= k; ++c) {
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j) {
f[i][j][c] = std::max(f[i - 1][j][c], f[i][j - 1][c]);
for (int l = 0; l < i; ++l)
f[i][j][c] = std::max(f[i][j][c], f[l][j][c - 1] +
sum[0][i] - sum[0][l]);
for (int l = 0; l < j; ++l)
f[i][j][c] = std::max(f[i][j][c], f[i][l][c - 1] +
sum[1][j] - sum[1][l]);
if (i == j) {
for (int l = 0; l < i; ++l)
f[i][j][c] = std::max(f[i][j][c], f[l][l][c - 1] +
sum[0][i] - sum[0][l] + sum[1][i] - sum[1][l]);
}
}
}
std::cout << f[n][n][k];
}
我们再来看下一种方法,定义f[i][j][k]表示当前枚举到第i行,已经选择了j个矩阵,第i行的状态,第i行的状态一共有5种,分别是0:表示两个都不选,1:只选择第一列的,2:只选择第二列的,3:两列都选并且两列在不同的矩阵中,4:两列都选,并且两列在同一个矩阵中,然后分情况进行讨论及转移即可。复杂度为O(nk),常数略大。
转移直接见代码吧,比较清楚。
Source:
inline void read_in() {
R(n), R(m), R(k);
for (int i = 1; i <= n; ++i)
for (int j = 0; j < m; ++j)
R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}
inline void solve_m_2() {
static int f[MAXN][MAXK][5];
/*
f[i][j][0] : i行,选择j个,本行两列均未选
f[i][j][1] : i行,选择j个,本行选择第一列
f[i][j][2] : i行,选择j个,本行选择第二列
f[i][j][3] : i行,选择j个,本行选择一二列,一二列不在同一矩阵中
f[i][j][4] : i行,选择j个,本行选择一二列,一二列在同一矩阵中
*/
memset(f, 128, sizeof(f)), f[0][0][0] = 0;
for (int i = 1; i <= n; ++i)
for (int j = 0; j <= k; ++j) {
for (int t = 0; t < 5; ++t) {
f[i][j][0] = std::max(f[i][j][0], f[i - 1][j][t]);
if (j > 0) {
f[i][j][1] = std::max(f[i][j][1],
f[i - 1][j - 1][t] + a[0][i]);
f[i][j][2] = std::max(f[i][j][2],
f[i - 1][j - 1][t] + a[1][i]);
f[i][j][4] = std::max(f[i][j][4],
f[i - 1][j - 1][t] + a[0][i] + a[1][i]);
}
if (j > 1) f[i][j][3] = std::max(f[i][j][3],
f[i - 1][j - 2][t] + a[0][i] + a[1][i]) ;
}
f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][1] + a[0][i]);
f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][3] + a[0][i]);
f[i][j][2] = std::max(f[i][j][2], f[i - 1][j][2] + a[1][i]);
f[i][j][2] = std::max(f[i][j][2], f[i - 1][j][3] + a[1][i]);
f[i][j][3] = std::max(f[i][j][3], f[i - 1][j][3] +
a[0][i] + a[1][i]);
if (j > 0) {
f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 1][1]
+ a[0][i] + a[1][i]);
f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 1][2]
+ a[0][i] + a[1][i]);
}
f[i][j][4] = std::max(f[i][j][4], f[i - 1][j][4] +
a[0][i] + a[1][i]);
}
int ans = -INF;
for (int i = 0; i < 5; ++i) ans = std::max(f[n][k][i], ans);
std::cout << ans;
}
最后贴两份总的代码
m = 1部分用O(n2k)实现, m = 2部分用O(n3k)实现。
Source:
/*
created by scarlyw
*/
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cctype>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <ctime>
inline char read() {
static const int IN_LEN = 1024 * 1024;
static char buf[IN_LEN], *s, *t;
if (s == t) {
t = (s = buf) + fread(buf, 1, IN_LEN, stdin);
if (s == t) return -1;
}
return *s++;
}
///*
template<class T>
inline void R(T &x) {
static bool iosig;
static char c;
for (iosig = false, c = read(); !isdigit(c); c = read()) {
if (c == -1) return ;
if (c == '-') iosig = true;
}
for (x = 0; isdigit(c); c = read()) x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN], *oh = obuf;
inline void write_char(char c) {
if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
*oh++ = c;
}
template<class T>
inline void W(T x) {
static int buf[30], cnt;
if (x == 0) write_char('0');
else {
if (x < 0) write_char('-'), x = -x;
for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
while (cnt) write_char(buf[cnt--]);
}
}
inline void flush() {
fwrite(obuf, 1, oh - obuf, stdout);
}
/*
template<class T>
inline void R(T &x) {
static bool iosig;
static char c;
for (iosig = false, c = getchar(); !isdigit(c); c = getchar()) {
if (c == -1) return ;
if (c == '-') iosig = true;
}
for (x = 0; isdigit(c); c = getchar()) x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int MAXN = 100 + 10;
int n, m, k;
int a[2][MAXN], sum[2][MAXN];
inline void read_in() {
R(n), R(m), R(k);
for (int i = 1; i <= n; ++i)
for (int j = 0; j < m; ++j)
R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}
inline void solve_m_1() {
static int f[MAXN][12];
memset(f, 128, sizeof(f));
for (int i = 0; i <= n; ++i) f[i][0] = 0;
for (int c = 1; c <= k; ++c)
for (int i = 1; i <= n; ++i) {
f[i][c] = f[i - 1][c];
for (int l = 0; l < i; ++l)
f[i][c] = std::max(f[i][c], sum[0][i] - sum[0][l] +
f[l][c - 1]);
}
std::cout << f[n][k];
}
inline void solve_m_2() {
static int f[MAXN][MAXN][12];
memset(f, 128, sizeof(f));
for (int i = 0; i <= n; ++i)
for (int j = 0; j <= n; ++j)
f[i][j][0] = 0;
for (int c = 1; c <= k; ++c) {
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j) {
f[i][j][c] = std::max(f[i - 1][j][c], f[i][j - 1][c]);
for (int l = 0; l < i; ++l)
f[i][j][c] = std::max(f[i][j][c], f[l][j][c - 1] +
sum[0][i] - sum[0][l]);
for (int l = 0; l < j; ++l)
f[i][j][c] = std::max(f[i][j][c], f[i][l][c - 1] +
sum[1][j] - sum[1][l]);
if (i == j) {
for (int l = 0; l < i; ++l)
f[i][j][c] = std::max(f[i][j][c], f[l][l][c - 1] +
sum[0][i] - sum[0][l] + sum[1][i] - sum[1][l]);
}
}
}
std::cout << f[n][n][k];
}
int main() {
// freopen("matrix.in", "r", stdin);
// freopen("matrix.out", "w", stdout);
read_in();
if (m == 1) solve_m_1();
else solve_m_2();
return 0;
}
m = 1,2部分均用O(nk)实现。
Source:
/*
created by scarlyw
*/
#include <cstdio>
#include <string>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <cmath>
#include <cctype>
#include <vector>
#include <set>
#include <queue>
inline char read() {
static const int IN_LEN = 1024 * 1024;
static char buf[IN_LEN], *s, *t;
if (s == t) {
t = (s = buf) + fread(buf, 1, IN_LEN, stdin);
if (s == t) return -1;
}
return *s++;
}
///*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = read(), iosig = false; !isdigit(c); c = read()) {
if (c == -1) return ;
if (c == '-') iosig = true;
}
for (x = 0; isdigit(c); c = read())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int OUT_LEN = 1024 * 1024;
char obuf[OUT_LEN], *oh = obuf;
inline void write_char(char c) {
if (oh == obuf + OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), oh = obuf;
*oh++ = c;
}
template<class T>
inline void W(T x) {
static int buf[30], cnt;
if (x == 0) write_char('0');
else {
if (x < 0) write_char('-'), x = -x;
for (cnt = 0; x; x /= 10) buf[++cnt] = x % 10 + 48;
while (cnt) write_char(buf[cnt--]);
}
}
inline void flush() {
fwrite(obuf, 1, oh - obuf, stdout);
}
/*
template<class T>
inline void R(T &x) {
static char c;
static bool iosig;
for (c = getchar(), iosig = false; !isdigit(c); c = getchar())
if (c == '-') iosig = true;
for (x = 0; isdigit(c); c = getchar())
x = ((x << 2) + x << 1) + (c ^ '0');
if (iosig) x = -x;
}
//*/
const int MAXN = 100 + 10;
const int MAXK = 12;
const int INF = 1000000000;
int n, m, k;
int a[2][MAXN], sum[2][MAXN];
inline void read_in() {
R(n), R(m), R(k);
for (int i = 1; i <= n; ++i)
for (int j = 0; j < m; ++j)
R(a[j][i]), sum[j][i] = sum[j][i - 1] + a[j][i];
}
inline void solve_m_1() {
static int f[MAXN][MAXK][2];
memset(f, 128, sizeof(f)), f[0][0][0] = 0;
for (int i = 1; i <= n; ++i)
for (int j = 0; j <= k; ++j) {
f[i][j][0] = std::max(f[i - 1][j][0], f[i - 1][j][1]);
if (j > 0) f[i][j][1] = std::max(f[i - 1][j - 1][0],
f[i - 1][j - 1][1]) + a[0][i];
f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][1] + a[0][i]);
}
std::cout << std::max(f[n][k][0], f[n][k][1]);
}
inline void solve_m_2() {
static int f[MAXN][MAXK][5];
memset(f, 128, sizeof(f)), f[0][0][0] = 0;
for (int i = 1; i <= n; ++i)
for (int j = 0; j <= k; ++j) {
for (int t = 0; t < 5; ++t) {
f[i][j][0] = std::max(f[i][j][0], f[i - 1][j][t]);
if (j > 0) {
f[i][j][1] = std::max(f[i][j][1],
f[i - 1][j - 1][t] + a[0][i]);
f[i][j][2] = std::max(f[i][j][2],
f[i - 1][j - 1][t] + a[1][i]);
f[i][j][4] = std::max(f[i][j][4],
f[i - 1][j - 1][t] + a[0][i] + a[1][i]);
}
if (j > 1) f[i][j][3] = std::max(f[i][j][3],
f[i - 1][j - 2][t] + a[0][i] + a[1][i]) ;
}
f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][1] + a[0][i]);
f[i][j][1] = std::max(f[i][j][1], f[i - 1][j][3] + a[0][i]);
f[i][j][2] = std::max(f[i][j][2], f[i - 1][j][2] + a[1][i]);
f[i][j][2] = std::max(f[i][j][2], f[i - 1][j][3] + a[1][i]);
f[i][j][3] = std::max(f[i][j][3], f[i - 1][j][3] +
a[0][i] + a[1][i]);
if (j > 0) {
f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 1][1]
+ a[0][i] + a[1][i]);
f[i][j][3] = std::max(f[i][j][3], f[i - 1][j - 1][2]
+ a[0][i] + a[1][i]);
}
f[i][j][4] = std::max(f[i][j][4], f[i - 1][j][4] +
a[0][i] + a[1][i]);
}
int ans = -INF;
for (int i = 0; i < 5; ++i) ans = std::max(f[n][k][i], ans);
std::cout << ans;
}
int main() {
read_in();
if (m == 1) solve_m_1();
else solve_m_2();
return 0;
}