题意:
求n∗m,n,m≤2000的字符串矩阵里,z有几个,单个z也算
‘z′字形:一个正方形中,第一行和最后一行以及副对角线都是z,其他的无所谓
分析:
首先我们有显然的O(n3)做法,预处理(i,j)的向左l[i][j],向右r[j][j],向左下ld[i][j]的z延伸距离
然后枚举O(n)枚举斜线,O(n2)起始位置和终止位置,根据之前的预处理O(1)判断是不是z
显然枚举斜线是不行的
z有三条线,不妨考虑枚举点,但一个点只能管2条线,能不能把一条线的贡献先加进去呢
显然根据我们的预处理应该枚举右上角的点,就把剩余的下面“一横”添加进去
![]()
考虑逆z的书写方向,把所有的“横线”,根据右上角的点的列标号,存起左端点
这样我们在累加答案的时候,从右往左枚举列,先把这一列所有的“横线”(即左端点)添加进相应的斜线的BIT中,添加y坐标就可以了
一个点(i,j)所在的副对角线用i+j标记
这样枚举的右上角的点(i,j)即可判断出“横折”的长度,即z=min(l[i][j],ld[i][j])
然后判断这条长度为z的“折”(斜线)上“横线”(即左端点)的个数,也就是i+j斜线上区间[j−z+1,j]有几个“横线”(即左端点)
即E=BITi+jsum(j−z+1,j)
这样时间复杂度就优化到O(nmlogm)了
代码:
//
// Created by TaoSama on 2016-02-22
// Copyright (c) 2016 TaoSama. All rights reserved.
//
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <queue>
#include <string>
#include <set>
#include <vector>
using namespace std;
#define pr(x) cout << #x << " = " << x << " "
#define prln(x) cout << #x << " = " << x << endl
const int N = 3e3 + 10, INF = 0x3f3f3f3f, MOD = 1e9 + 7;
int n, m;
char s[N][N];
int b[N << 1][N];
int l[N][N], r[N][N], ld[N][N];
vector<pair<int, int> > leftPoints[N];
void add(int *b, int i, int v) {
for(; i <= m; i += i & -i) b[i] += v;
}
int sum(int *b, int i) {
int ret = 0;
for(; i; i -= i & -i) ret += b[i];
return ret;
}
int main() {
#ifdef LOCAL
freopen("C:\\Users\\TaoSama\\Desktop\\in.txt", "r", stdin);
// freopen("C:\\Users\\TaoSama\\Desktop\\out.txt","w",stdout);
#endif
ios_base::sync_with_stdio(0);
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; ++i) scanf("%s", s[i] + 1);
for(int i = 1; i <= n; ++i) {
for(int j = 1; j <= m; ++j)
l[i][j] = s[i][j] == 'z' ? l[i][j - 1] + 1 : 0;
for(int j = m; j; --j)
r[i][j] = s[i][j] == 'z' ? r[i][j + 1] + 1 : 0;
}
for(int i = n; i; --i)
for(int j = 1; j <= m; ++j)
ld[i][j] = s[i][j] == 'z' ? ld[i + 1][j - 1] + 1 : 0;
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j)
leftPoints[j + r[i][j] - 1].push_back({i, j});
long long ans = 0;
for(int j = m; j; --j) {
for(auto &p : leftPoints[j]) {
int x = p.first, y = p.second;
add(b[x + y], y, 1);
}
for(int i = 1; i <= n; ++i) {
int z = min(l[i][j], ld[i][j]);
if(!z) continue;
ans += sum(b[i + j], j) - sum(b[i + j], j - z);
}
}
printf("%I64d\n", ans);
return 0;
}