题意
给定一个 n × m n \times m n×m 的带权矩阵,求从 ( 1 , 1 ) (1, 1) (1,1) 到 ( n , m ) (n, m) (n,m) 的两个没有交点的路径和的最大值。
分析
虽然题目限制的是同一个点不能重复经过两次,但如果一个最优解路径是有交点的,那么一定可以转化为一个路径没有交点的最优解。证明在证明一。还可以这样思考,可以把交点处的权值只加一次。因为每个点只能使用一次,那么第二次使用时的权值就为 0 0 0 ,这样实际就转化成了 方格取数 ,事实上两题本质是一样的,方格取数的代码可以通过本题。
状态表示:f[k][i1][i2]
表示两个人同时走了 k
步,第一个人在 (i1, k - i1)
处,第二个人在 (i2, k - i2)
处的所有走法的最大分值。
状态计算: 根据最后一步两个人的走法可以分为 4 4 4 种转移方式。
- 两人同时向右走,最大分值是
f[k - 1][i1][i2] + v
。 - 第一个人向右走,第二个人向下走,最大分值是
f[k - 1][i1][i2 - 1] + v
。 - 第一个人向下走,第二个人向右走,最大分值是
f[k - 1][i1 - 1][i2] + v
。 - 两个人同时向下走,最大分值是
f[k - 1][i1 - 1][i2 - 1] + v
。
这里的 v
,只有起点和终点是特殊的 v = w[i1][k - i1]
,因为起点和终点的权值只需要加一次,而且起点终点还具有 i1 == i2
的性质。其余所有点都是一样的 v = w[i1][k - i1] + w[i2][k - i2]
。
时间复杂度: 一共有 O ( n 3 ) O(n^3) O(n3) 个状态,每个状态需要 O ( 1 ) O(1) O(1) 的计算量。因此,总时间复杂度为 O ( n 3 ) O(n ^ 3) O(n3) 。
代码
递归写法:
// Problem: 传纸条
// URL: https://www.acwing.com/problem/content/description/277/
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
const int N = 55, INF = 0x3f3f3f3f;
int n, m;
int w[N][N];
int f[N * 2][N][N];
int dp(int k, int i1, int i2) {
// 记忆化
int &res = f[k][i1][i2];
if (res != -1) {
return res;
}
// 越界判断
if (i1 <= 0 || i2 <= 0 || i1 >= k || i2 >= k) {
return 0;
}
// 状态转移
if (i1 != i2 || k == 2 || k == n + m) {
int j1 = k - i1, j2 = k - i2;
int v = w[i1][j1];
if (i1 != i2) {
v += w[i2][j2];
}
for (int a = 0; a <= 1; a ++ ) {
for (int b = 0; b <= 1; b ++ ) {
res = max(res, dp(k - 1, i1 - a, i2 - b) + v);
}
}
} else {
res = 0;
}
return res;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; i ++ ) {
for (int j = 1; j <= m; j ++ ) {
cin >> w[i][j];
}
}
memset(f, -1, sizeof f);
cout << dp(n + m, n, n) << "\n";
return 0;
}
循环写法:
// Problem: 传纸条
// URL: https://www.acwing.com/problem/content/description/277/
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
const int N = 55;
int n, m;
int w[N][N];
int f[N * 2][N][N];
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; i ++ ) {
for (int j = 1; j <= m; j ++ ) {
cin >> w[i][j];
}
}
// dp
for (int k = 2; k <= n + m; k ++ ) {
for (int i1 = max(1, k - m); i1 <= min(n, k - 1); i1 ++ ) {
for (int i2 = max(1, k - m); i2 <= min(n, k - 1); i2 ++ ) {
for (int a = 0; a <= 1; a ++ ) {
for (int b = 0; b <= 1; b ++ ) {
if (i1 != i2 || k == 2 || k == n + m) {
int j1 = k - i1, j2 = k - i2;
int v = w[i1][j1];
if (i1 != i2) {
v += w[i2][j2];
}
f[k][i1][i2] = max(f[k][i1][i2], f[k - 1][i1 - a][i2 - b] + v);
}
}
}
}
}
}
cout << f[n + m][n][n] << "\n";
return 0;
}