首先我们考虑偶数个点和奇数个点的方阵枚举中心方式不太相同,我们用类似manacher的处理方法,填上一堆0,把他们全都变成奇数的情况。然后我们枚举每一个点作为中心,二分答案找到以这个点为中心最大的合法方阵。就可以直接统计这个点对答案的贡献了。这样已经是 O(n2logn) 的了,我们需要O(1)判断一个方阵是否上下左右均对称。类似不用manacher求最长回文子串的方法,把这个子串镜像过来求最长公共子串,我们分别做出这个矩阵的上下镜面和左右镜面,然后每次就只需要判定这三个方阵区域是否相同。可以用二维hash预处理来O(1)判断。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define ll long long
#define N 2010
#define inf 0x3f3f3f3f
#define k1 1000003
#define k2 101
#define uint unsigned int
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return x*f;
}
int n,m,a[N][N],ans=0;
uint hs[3][N][N],bin1[N],bin2[N];
inline uint hash(int op,int x1,int y1,int x2,int y2){
uint res=hs[op][x2][y2]-hs[op][x1-1][y2]*bin1[x2-x1+1]-hs[op][x2][y1-1]*bin2[y2-y1+1];
return res+hs[op][x1-1][y1-1]*bin2[y2-y1+1]*bin1[x2-x1+1];
}
inline bool jud(int x,int y,int len){
int x1=x-len+1,x2=x+len-1,y1=y-len+1,y2=y+len-1;
uint v1=hash(0,x1,y1,x2,y2),v2=hash(1,n-x2+1,y1,n-x1+1,y2),v3=hash(2,x1,m-y2+1,x2,m-y1+1);
if(v1!=v2||v1!=v3) return 0;
return 1;
}
int main(){
// freopen("a.in","r",stdin);
n=read();m=read();bin1[0]=bin2[0]=1;
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j) a[i*2-1][j*2-1]=read();
n=n*2-1;m=m*2-1;
for(int i=1;i<=n;++i) bin1[i]=bin1[i-1]*k1;
for(int i=1;i<=m;++i) bin2[i]=bin2[i-1]*k2;
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j) hs[0][i][j]=hs[0][i][j-1]*k2+a[i][j];
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j) hs[0][i][j]+=hs[0][i-1][j]*k1;
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j) hs[1][i][j]=hs[1][i][j-1]*k2+a[n-i+1][j];
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j) hs[1][i][j]+=hs[1][i-1][j]*k1;
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j) hs[2][i][j]=hs[2][i][j-1]*k2+a[i][m-j+1];
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j) hs[2][i][j]+=hs[2][i-1][j]*k1;
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j){
if(i+j&1) continue;
int l=1,r=min(min(i,n-i+1),min(j,m-j+1));
while(l<=r){
int mid=l+r>>1;
if(jud(i,j,mid)) l=mid+1;else r=mid-1;
}l--;if(i&1) ans+=l+1>>1;else ans+=l>>1;
}printf("%d\n",ans);
return 0;
}