对称的正方形
时间限制: 1000 M S 1000MS 1000MS 内存限制: 128 M B 128 MB 128MB
问题描述
O r e z Orez Orez很喜欢搜集一些神秘的数据,并经常把它们排成一个矩阵进行研究。最近, O r e z Orez Orez又得到了一些数据,并已经把它们排成了一个 n n n行 m m m列的矩阵。通过观察, O r e z Orez Orez发现这些数据蕴涵了一个奇特的数,就是矩阵中上下对称且左右对称的正方形子矩阵的个数。 O r e z Orez Orez自然很想知道这个数是多少,可是矩阵太大,无法去数。只能请你编个程序来计算出这个数。
输入格式
文件的第一行为两个整数 n n n和 m m m。接下来 n n n行每行包含 m m m个正整数,表示 O r e z Orez Orez得到的矩阵。
输出格式
文件中仅包含一个整数 a n s w e r answer answer,表示矩阵中有 a n s w e r answer answer个上下左右对称的正方形子矩阵。
样例输入
5 5 5 5 5 5
4 4 4 2 2 2 4 4 4 4 4 4 4 4 4
3 3 3 1 1 1 4 4 4 4 4 4 3 3 3
3 3 3 5 5 5 3 3 3 3 3 3 3 3 3
3 3 3 1 1 1 5 5 5 3 3 3 3 3 3
4 4 4 2 2 2 1 1 1 2 2 2 4 4 4样例输出
27 27 27
数据范围
对于 30 30 30%的数据 n , m ≤ 100 n,m≤100 n,m≤100
对于 100 100 100%的数据 n , m ≤ 1000 n,m≤1000 n,m≤1000 ,矩阵中的数的大小 ≤ 1 0 9 ≤10^9 ≤109
解析
听说有很多神仙是用
M
a
n
a
c
h
e
r
Manacher
Manacher做的 ,然而本蒟蒻并不会 。
所以我们用一种比较简单粗暴且易于理解的算法——
h
a
s
h
hash
hash来替代。(说白了就是弱)
首先我们会发现两个很 (显然且) 有用的性质:
- 如果一个正方形子矩阵是对称的且边长 > 2 >2 >2,那么比它小一圈的正方形子矩阵也一定是对称的。
- 正方形子矩阵的对称中心至多只有
O
(
2
n
m
)
O(2nm)
O(2nm)个
那么我们可以枚举正方形子矩阵的对称中心,并二分此对称中心的最大边长。
至于判定该正方形子矩阵是否对称,我们可以通过矩阵hash来解决。
设一个矩阵 a [ 1.. i ] [ 1.. j ] a[1..i][1..j] a[1..i][1..j]的 h a s h hash hash值为 Σ p 1 i ∗ p 2 j ∗ a [ i ] [ j ] Σp_1^i*p_2^j*a[i][j] Σp1i∗p2j∗a[i][j]
我们把原矩阵、原矩阵上下翻转、原矩阵左右翻转分别做一次 h a s h hash hash,判定时只要把对应矩阵的 h a s h hash hash值用二维前缀和求出来并简单处理一下行差、列差对 p 1 , p 2 p_1,p_2 p1,p2乘方次数的影响之后判断是否相等即可。
T i p s : Tips: Tips:
- 本题时限较紧,请提前预处理 p 1 i , p 2 j p_1^i,p_2^j p1i,p2j
- 如果你对自己的常数不是非常自信的话请不要写双
h
a
s
h
hash
hash,写了也不要用
p
a
i
r
pair
pair
cyl大佬已经身先士卒地T了 - 别把 m m m打成 n n n,本蒟蒻对此已经不想说什么了
- 枚举偶数边长时不一定有答案
代码
#include <cstdio>
#define ll long long
using namespace std;
const int maxn = 1005;
const int p1 = 29;
const int p2 = 31;
const int mod = 1e9 + 7;
int n , m;
int a[maxn][maxn] , b[maxn][maxn] , c[maxn][maxn];
ll pow_x[maxn] , pow_y[maxn];
int min(int x , int y){return x < y ? x : y;}
int read()
{
char ch = getchar(); bool f = 1;
while(ch < '0' || ch > '9') f &= ch != '-' , ch = getchar();
int res = 0;
while(ch >= '0' && ch <= '9') res = (res << 3) + (res << 1) + (ch ^ 48) , ch = getchar();
return f ? res : -res;
}
void pow_init()
{
pow_x[0] = pow_y[0] = 1;
for(int i = 1;i <= n;i++) pow_x[i] = pow_x[i - 1] * p1 % mod;
for(int i = 1;i <= m;i++) pow_y[i] = pow_y[i - 1] * p2 % mod;
}
struct HASH
{
private:
ll s[maxn][maxn];
public:
void init(int (*x)[maxn])
{
for(int i = 1;i <= n;i++)
for(int j = 1;j <= m;j++)
{
ll tmp = pow_x[i] * pow_y[j] % mod * x[i][j] % mod;
s[i][j] = ((s[i - 1][j] + s[i][j - 1]) % mod - s[i - 1][j - 1] + tmp + mod) % mod;
}
}
ll sum(int bx , int by , int ex , int ey){return ((s[ex][ey] - s[bx - 1][ey] - s[ex][by - 1] + s[bx - 1][by - 1]) % mod + mod) % mod;}
}hash1 , hash2 , hash3;
bool check1(int len , int i , int j)
{
int bx1 = i - len , by1 = j - len , ex1 = i + len , ey1 = j + len;
ll res1 = hash1.sum(bx1 , by1 , ex1 , ey1);
int bx2 = bx1 , by2 = m - ey1 + 1 , ex2 = ex1 , ey2 = m - by1 + 1;
ll res2 = hash2.sum(bx2 , by2 , ex2 , ey2);
int bx3 = n - ex1 + 1 , by3 = by1 , ex3 = n - bx1 + 1 , ey3 = ey1;
ll res3 = hash3.sum(bx3 , by3 , ex3 , ey3);
ll res4 = res1;
if(by1 > by2) res2 = (res2 * pow_y[by1 - by2]) % mod;
if(by1 < by2) res1 = (res1 * pow_y[by2 - by1]) % mod;
if(bx1 > bx3) res3 = (res3 * pow_x[bx1 - bx3]) % mod;
if(bx1 < bx3) res4 = (res4 * pow_x[bx3 - bx1]) % mod;
return res1 == res2 && res4 == res3;
}
bool check2(int len , int i , int j)
{
int bx1 = i - len , by1 = j - len , ex1 = i + len + 1 , ey1 = j + len + 1;
ll res1 = hash1.sum(bx1 , by1 , ex1 , ey1);
int bx2 = bx1 , by2 = m - ey1 + 1 , ex2 = ex1 , ey2 = m - by1 + 1;
ll res2 = hash2.sum(bx2 , by2 , ex2 , ey2);
int bx3 = n - ex1 + 1 , by3 = by1 , ex3 = n - bx1 + 1 , ey3 = ey1;
ll res3 = hash3.sum(bx3 , by3 , ex3 , ey3);
ll res4 = res1;
if(by1 > by2) res2 = (res2 * pow_y[by1 - by2]) % mod;
if(by1 < by2) res1 = (res1 * pow_y[by2 - by1]) % mod;
if(bx1 > bx3) res3 = (res3 * pow_x[bx1 - bx3]) % mod;
if(bx1 < bx3) res4 = (res4 * pow_x[bx3 - bx1]) % mod;
return res1 == res2 && res4 == res3;
}
int main()
{
n = read() , m = read();
for(int i = 1;i <= n;i++)
for(int j = 1;j <= m;j++) a[i][j] = b[i][m - j + 1] = c[n - i + 1][j] = read();
pow_init();
hash1.init(a) , hash2.init(b) , hash3.init(c);
int ans = 0;
for(int i = 1;i <= n;i++)
for(int j = 1;j <= m;j++)
{
int l = 0 , r = min(min(i - 1 , n - i) , min(j - 1 , m - j));
while(l < r)
{
int mid = l + r + 1 >> 1;
if(check1(mid , i , j)) l = mid;
else r = mid - 1;
}
ans += l + 1;
}
for(int i = 1;i < n;i++)
for(int j = 1;j < m;j++)
{
int l = 0 , r = min(min(i - 1 , n - i - 1) , min(j - 1 , m - j - 1)) , res = -1;
while(l <= r)
{
int mid = l + r >> 1;
if(check2(mid , i , j)) res = mid , l = mid + 1;
else r = mid - 1;
}
ans += res + 1;
}
printf("%d\n",ans);
return 0;
}