题意:给你一个01矩阵,问此矩阵有多少个和恰好为k的子矩形。
思路:分治,对于当前矩形,用一条中线把矩形分成两半,分治之后计算跨过中线的矩形个数。更具体的来说(假设划了一条水平中线),我们枚举矩形左右边界,然后用指针维护一下到中线的连续和为k的边界。之后通过差分就可以计算出对应的左右边界的矩形的贡献数目。对于一个n * m的矩阵,计算贡献的时间复杂度是O(n * (m * k + n))的,带有n * n项,所以计算的时候需要用交替画水平线和竖直线,不然就超时了。总复杂度O(n * m * k * ( log(n) + log(m) ) );
代码:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2510;
int sum[maxn][maxn];
char s[maxn][maxn];
int n, m, K;
long long ans;
int p1[10], p2[10];
int cal(int x1, int y1, int x2, int y2) {
return sum[x2][y2] - sum[x1 - 1][y2] - sum[x2][y1 - 1] + sum[x1 - 1][y1 - 1];
}
void div(int x1, int y1, int x2, int y2, bool flag) {
if(x1 > x2 || y1 > y2) return;
if(x1 == x2 && y1 == y2) {
ans += (cal(x1, y1, x2, y2) == K);
return;
}
if(flag) {
int mid = (x1 + x2) >> 1;
div(x1, y1, mid, y2, !flag);
div(mid + 1, y1, x2, y2, !flag);
for (int i = y1; i <= y2; i++) {
for (int k = 0; k <= K; k++) {
p1[k] = mid, p2[k] = mid + 1;
}
for (int j = y2; j >= i; j--) {
for (int k = 0; k <= K; k++) {
while(p1[k] >= x1 && cal(p1[k], i, mid, j) <= k) p1[k]--;
while(p2[k] <= x2 && cal(mid + 1, i, p2[k], j) <= k) p2[k]++;
}
for (int k = 1; k < K; k++) {
ans += (p1[k - 1] - p1[k]) * (p2[K - k] - p2[K - k - 1]);
}
if(K > 0) {
ans += (mid - p1[0]) * (p2[K] - p2[K - 1]);
ans += (p2[0] - mid - 1) * (p1[K - 1] - p1[K]);
} else if(K == 0) {
ans += (mid - p1[0]) * (p2[0] - mid - 1);
}
}
}
}
else {
int mid = (y1 + y2) >> 1;
div(x1, y1, x2, mid, !flag);
div(x1, mid + 1, x2, y2, !flag);
for (int i = x1; i <= x2; i++) {
for (int k = 0; k <= K; k++) {
p1[k] = mid, p2[k] = mid + 1;
}
for (int j = x2; j >= i; j--) {
for (int k = 0; k <= K; k++) {
while(p1[k] >= y1 && cal(i, p1[k], j, mid) <= k) p1[k]--;
while(p2[k] <= y2 && cal(i, mid + 1, j, p2[k]) <= k) p2[k]++;
}
for (int k = 1; k < K; k++) {
ans += (p1[k - 1] - p1[k]) * (p2[K - k] - p2[K - k - 1]);
}
if(K > 0) {
ans += (mid - p1[0]) * (p2[K] - p2[K - 1]);
ans += (p2[0] - mid - 1) * (p1[K - 1] - p1[K]);
} else if(K == 0) {
ans += (mid - p1[0]) * (p2[0] - mid - 1);
}
//printf("%d %d %d %d %d %d\n", x1, y1, x2, y2, 2, ans);
}
}
}
}
int main() {
//freopen("out.txt", "r", stdin);
scanf("%d%d%d", &n, &m, &K);
for (int i = 1; i <= n; i++) {
scanf("%s", s[i] + 1);
}
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++) {
sum[i][j] = sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1] + (s[i][j] == '1');
}
div(1, 1, n, m, 0);
printf("%lld\n", ans);
}