这道题是一道很好的二位前缀和问题。
然而码量有点大。
下面规定 n n n 表示行, m m m 表示列, n , m n,m n,m 同阶。
即计算复杂度的时候视 O ( n m ) O(nm) O(nm) 为 O ( n 2 ) O(n^2) O(n2)。
首先预处理 s u m i , j sum_{i,j} sumi,j 表示从 ( 1 , 1 ) (1,1) (1,1) 到 ( i , j ) (i,j) (i,j) 的和,也就是二维前缀和,这里略过。
考虑将一个矩阵分成 3 块无非 6 种情况,如下:
上面 6 种情况又分为 2 类:
第一类是前面两种全横排与全竖排。
对于这一类情况,我们需要 O ( n 2 ) O(n^2) O(n2) 枚举两个分割行,然后 O ( 1 ) O(1) O(1) 求出答案。
因此这里引入两个辅助数组 L i n e i , j , C o l i , j Line_{i,j},Col_{i,j} Linei,j,Coli,j。
- L i n e i , j Line_{i,j} Linei,j 表示第 i i i 行到第 j j j 行之间的最大 k × k k \times k k×k 子矩阵元素和。
- C o l i , j Col_{i,j} Coli,j 表示第 i i i 列到第 j j j 列之间的最大 k × k k \times k k×k 子矩阵元素和。
然后考虑如何预处理出这两个数组。
下面以 L i n e i , j Line_{i,j} Linei,j 为例。
首先,对于形如 L i n e i , i + k − 1 Line_{i,i+k-1} Linei,i+k−1 的结果,由于总共只有 k k k 行,此时我们可以枚举这 k k k 行里面的子矩阵,至多只有 m − k + 1 m-k+1 m−k+1 个。
这部分复杂度是 O ( n 2 ) O(n^2) O(n2)。
然后对于任意的 L i n e i , j Line_{i,j} Linei,j,如果 j − i + 1 < k j-i+1<k j−i+1<k,说明此时行数非法,空着就好(无贡献),否则采用区间 DP 方式转移: L i n e i , j = max ( L i n e i + 1 , j , L i n e i , j − 1 ) Line_{i,j}=\max(Line_{i+1,j},Line_{i,j-1}) Linei,j=max(Linei+1,j,Linei,j−1)。
处理完 L i n e , C o l Line,Col Line,Col 之后,我们就可以 O ( n 2 ) O(n^2) O(n2) 枚举, O ( 1 ) O(1) O(1) 算贡献了。
第二类是后面的四种情况,图我搬过来了:
细心观察上图可以发现,这一类情况的图中都有一个交点(即图中的红点)。
因此我们考虑 O ( n 2 ) O(n^2) O(n2) 枚举这个交点。
但是这样我们依然要 O ( 1 ) O(1) O(1) 算贡献。
因此我们除了前面的 L i n e , C o l Line,Col Line,Col,还要引入一个辅助数组:
f i , j , 0 / 1 / 2 / 3 f_{i,j,0/1/2/3} fi,j,0/1/2/3 表示以 ( i , j ) (i,j) (i,j) 为分界点,左上/右上/左下/右下 的最大 k × k k \times k k×k 子矩阵元素和。包括 ( i , j ) (i,j) (i,j) 这个点。
也就是向下面这张图:
接下来以预处理 f i , j , 0 f_{i,j,0} fi,j,0 为例:
还是看图。
假设右下角的点为 ( i , j ) (i,j) (i,j),我们要算 f i , j , 0 f_{i,j,0} fi,j,0。
这个矩形被我分成了 3 类:
- 绿色的是 f i , j − 1 , 0 f_{i,j-1,0} fi,j−1,0。
- 蓝色的是 f i − 1 , j , 0 f_{i-1,j,0} fi−1,j,0。
- 黄色的是 ( i − k + 1 , j − k + 1 ) (i-k+1,j-k+1) (i−k+1,j−k+1) 到 ( i , j ) (i,j) (i,j) 这一块矩形的和。
显然答案只能是上述 3 者的最大值,于是我们可以 O ( n 2 ) O(n^2) O(n2) 处理完 f i , j , 0 f_{i,j,0} fi,j,0。
别的同理。
综上,对于第二类情况,我们就可以 O ( n 2 ) O(n^2) O(n2) 枚举交点, O ( 1 ) O(1) O(1) 计算答案了。
最后的时间复杂度是 O ( n 2 ) O(n^2) O(n2)。
几个小细节:
- 注意在计算 f i , j , 0 / 1 / 2 / 3 , L i n e i , j , C o l i , j f_{i,j,0/1/2/3},Line_{i,j},Col_{i,j} fi,j,0/1/2/3,Linei,j,Coli,j 的时候是否合法的判断。
- 在转移的时候要注意分割线不能算两次。
比如说下面这样:
此时的正确答案计算式的其中一种为 f i , j , 0 + f i , j + 1 , 1 + L i n e i + 1 , n f_{i,j,0}+f_{i,j+1,1}+Line_{i+1,n} fi,j,0+fi,j+1,1+Linei+1,n。
千万注意分界线不能算两次。
- 注意 n , m n,m n,m 不能弄错。
- 在计算二维前缀和的时候注意不能出现左上角+右下角与左下角+右上角计算两种情况混着用。
换句话说,假设当前矩阵左上,右上,左下,右下分别为 ( 2 , 2 ) , ( 2 , 4 ) , ( 4 , 2 ) , ( 4 , 4 ) (2,2),(2,4),(4,2),(4,4) (2,2),(2,4),(4,2),(4,4),不能出现传进去 ( 4 , 2 ) , ( 2 , 4 ) (4,2),(2,4) (4,2),(2,4) 并且将其当成矩形左上角与右下角 的错误。 - 不能开
long long
,否则会 MLE。
至于答案手算一下就会发现实际上不会炸int
。
代码:
/*
========= Plozia =========
Author:Plozia
Problem:P3625 [APIO2009]采油区域
Date:2021/5/14
========= Plozia =========
*/
#include <bits/stdc++.h>
typedef long long LL;
const int MAXN = 1500 + 10;
int n, m, k, a[MAXN][MAXN], sum[MAXN][MAXN], f[MAXN][MAXN][4], Line[MAXN][MAXN], Col[MAXN][MAXN];
int read()
{
int sum = 0, fh = 1; char ch = getchar();
for (; ch < '0' || ch > '9'; ch = getchar()) fh -= (ch == '-') << 1;
for (; ch >= '0' && ch <= '9'; ch = getchar()) sum = (sum << 3) + (sum << 1) + (ch ^ 48);
return sum * fh;
}
int Max(int fir, int sec) { return (fir > sec) ? fir : sec; }
int Min(int fir, int sec) { return (fir < sec) ? fir : sec; }
int Get(int r1, int c1, int r2, int c2)
{
if (r1 > r2) std::swap(r1, r2); if (c1 > c2) std::swap(c1, c2);
return sum[r2][c2] - sum[r2][c1 - 1] - sum[r1 - 1][c2] + sum[r1 - 1][c1 - 1];
}
void init()
{
f[k][k][0] = Get(1, 1, k, k);
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
{
if (i > k) { f[i][j][0] = Max(f[i][j][0], f[i - 1][j][0]); }
if (j > k) { f[i][j][0] = Max(f[i][j][0], f[i][j - 1][0]); }
if (i - k + 1 > 0 && j - k + 1 > 0) { f[i][j][0] = Max(f[i][j][0], Get(i - k + 1, j - k + 1, i, j)); }
}
f[k][m - k + 1][1] = Get(1, m - k + 1, k, m);
for (int i = 1; i <= n; ++i)
for (int j = m; j >= 1; --j)
{
if (i > k) { f[i][j][1] = Max(f[i][j][1], f[i - 1][j][1]); }
if (j < m - k + 1) { f[i][j][1] = Max(f[i][j][1], f[i][j + 1][1]); }
if (i - k + 1 > 0 && j + k - 1 <= m) { f[i][j][1] = Max(f[i][j][1], Get(i - k + 1, j, i, j + k - 1)); }
}
f[n - k + 1][k][2] = Get(n - k + 1, 1, n, k);
for (int i = n; i >= 1; --i)
for (int j = 1; j <= m; ++j)
{
if (i < n - k + 1) { f[i][j][2] = Max(f[i][j][2], f[i + 1][j][2]); }
if (j > k) { f[i][j][2] = Max(f[i][j][2], f[i][j - 1][2]); }
if (i + k - 1 <= n && j - k + 1 > 0) { f[i][j][2] = Max(f[i][j][2], Get(i, j - k + 1, i + k - 1, j)); }
}
f[n - k + 1][m - k + 1][3] = Get(n - k + 1, m - k + 1, n, m);
for (int i = n; i >= 1; --i)
for (int j = m; j >= 1; --j)
{
if (i < n - k + 1) { f[i][j][3] = Max(f[i][j][3], f[i + 1][j][3]); }
if (j < m - k + 1) { f[i][j][3] = Max(f[i][j][3], f[i][j + 1][3]); }
if (i + k - 1 <= n && j + k - 1 <= m) { f[i][j][3] = Max(f[i][j][3], Get(i, j, i + k - 1, j + k - 1)); }
}
}
int main()
{
n = read(), m = read(), k = read();
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
a[i][j] = read();
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
sum[i][j] = sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1] + a[i][j];
init();
for (int i = 1; i + k - 1 <= n; ++i)
{
for (int j = 1; j <= m; ++j)
{
if (j >= k) { Line[i][i + k - 1] = Max(Line[i][i + k - 1], Get(i, j - k + 1, i + k - 1, j)); }
if (j + k - 1 <= m) { Line[i][i + k - 1] = Max(Line[i][i + k - 1], Get(i, j, i + k - 1, j + k - 1)); }
}
}
for (int len = k + 1; len <= n; ++len)
{
for (int i = 1; i <= n; ++i)
{
int j = i + len - 1; if (j > n) break ;
Line[i][j] = Max(Line[i + 1][j], Line[i][j - 1]);
}
}
for (int j = 1; j + k - 1 <= m; ++j)
{
for (int i = 1; i <= n; ++i)
{
if (i >= k) { Col[j][j + k - 1] = Max(Col[j][j + k - 1], Get(i - k + 1, j, i, j + k - 1)); }
if (i + k - 1 <= n) { Col[j][j + k - 1] = Max(Col[j][j + k - 1], Get(i, j, i + k - 1, j + k - 1)); }
}
}
for (int len = k + 1; len <= m; ++len)
{
for (int i = 1; i <= m; ++i)
{
int j = i + len - 1; if (j > m) break ;
Col[i][j] = Max(Col[i + 1][j], Col[i][j - 1]);
}
}
int ans = 0;
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
{
ans = Max(ans, Col[1][j - 1] + f[i][j][1] + f[i + 1][j][3]);
ans = Max(ans, Col[j + 1][m] + f[i][j][0] + f[i + 1][j][2]);
ans = Max(ans, f[i][j][0] + f[i][j + 1][1] + Line[i + 1][n]);
ans = Max(ans, f[i][j][2] + f[i][j + 1][3] + Line[1][i - 1]);
}
for (int i = k; i <= n - k + 1; ++i)
for (int j = i + k - 1; j <= n - k + 1; ++j)
ans = Max(ans, Line[1][i] + Line[i + 1][j - 1] + Line[j][n]);
for (int i = k; i <= m - k + 1; ++i)
for (int j = i + k - 1; j <= m - k + 1; ++j)
ans = Max(ans, Col[1][i] + Col[i + 1][j - 1] + Col[j][m]);
printf("%d\n", ans); return 0;
}