题目描述
设有 n × m n \times m n×m 的方格图,每个方格中都有一个整数。现有一只小熊,想从图的左上角走到右下角,每一步只能向上、向下或向右走一格,并且不能重复经过已经走过的方格,也不能走出边界。小熊会取走所有经过的方格中的整数,求它能取到的整数之和的最大值。
输入格式
第一行有两个整数 n , m n, m n,m。
接下来 n n n 行每行 m m m 个整数,依次代表每个方格中的整数。
输出格式
一个整数,表示小熊能取到的整数之和的最大值。
样例 #1
样例输入 #1
3 4
1 -1 3 2
2 -1 4 -1
-2 2 -3 -1
样例输出 #1
9
样例 #2
样例输入 #2
2 5
-1 -1 -3 -2 -7
-2 -1 -4 -1 -2
样例输出 #2
-10
提示
样例 1 解释
样例 2 解释
数据规模与约定
- 对于 20 % 20\% 20% 的数据, n , m ≤ 5 n, m \le 5 n,m≤5。
- 对于 40 % 40\% 40% 的数据, n , m ≤ 50 n, m \le 50 n,m≤50。
- 对于 70 % 70\% 70% 的数据, n , m ≤ 300 n, m \le 300 n,m≤300。
- 对于 100 % 100\% 100% 的数据, 1 ≤ n , m ≤ 1 0 3 1 \le n,m \le 10^3 1≤n,m≤103。方格中整数的绝对值不超过 1 0 4 10^4 104。
思路
dp,定义dp[i][j][0/1]表示从点i,j,方向向下/上,所能达到的最大值。
转移方程有3个方向
- 从相邻左列,转移到当前列。
- 从相邻上一行,转移到当前行。
- 从相邻下一行,转移到当前行。
最后,答案为max(dp[n][m][0],dp[n][m][1])
代码
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pcc pair<char, char>
#define inf 0x3f3f3f3f
const int maxn = 1010;
int n, m;
int a[maxn][maxn];
// dir: 0-> down 1-> up
ll dp[maxn][maxn][2];
bool vis[maxn][maxn][2];
void solve() {
scanf("%d%d", &n, &m);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) {
scanf("%d", &a[i][j]);
}
}
memset(vis, 0, sizeof(vis));
dp[0][0][0] = dp[0][0][1] = a[0][0];
vis[0][0][0] = vis[0][0][1] = 1;
for (int i = 1; i < n; ++i) {
if (vis[i-1][0][0]) {
dp[i][0][0] = a[i][0] + dp[i-1][0][0];
vis[i][0][0] = 1;
}
}
for (int j = 1; j < m; ++j) {
// from right direction
for (int i = 0; i < n; ++i) {
dp[i][j][0] = dp[i][j][1] = -inf;
ll tmp = -inf;
bool ok = 0;
if (vis[i][j-1][0]) {
tmp = max(tmp, dp[i][j-1][0]);
ok = 1;
}
if (vis[i][j-1][1]) {
tmp = max(tmp, dp[i][j-1][1]);
ok = 1;
}
if (ok) {
dp[i][j][0] = max(dp[i][j][0], tmp + a[i][j]);
dp[i][j][1] = max(dp[i][j][1], tmp + a[i][j]);
vis[i][j][0] = vis[i][j][1] = 1;
}
}
// from down direction
for (int i = 1; i < n; ++i) {
if (vis[i-1][j][0]) {
dp[i][j][0] = max(dp[i][j][0], dp[i-1][j][0] + a[i][j]);
vis[i][j][0] = 1;
}
}
// from up direction
for (int i = n - 2; i >= 0; --i) {
if (vis[i+1][j][1]) {
dp[i][j][1] = max(dp[i][j][1], dp[i+1][j][1] + a[i][j]);
vis[i][j][1] = 1;
}
}
}
ll ans = -inf;
if (vis[n-1][m-1][0]) {
ans = max(ans, dp[n-1][m-1][0]);
}
if (vis[n-1][m-1][1]) {
ans = max(ans, dp[n-1][m-1][1]);
}
printf("%lld\n", ans);
}
int main() {
int t = 1;
// scanf("%d", &t);
int cas = 1;
while (t--) {
// printf("cas %d:\n", cas++);
solve();
}
}
/*
999
1 1
233
1 2
-9 -8
2 1
-9
-8
*/