先预处理得出以每个点向左向上不含有相同字母的子串的最大长度,然后再进行处理。这样就能以52*nm的复杂度解决问题。
#include<bits/stdc++.h>
using namespace std;
#define LL long long int
#define lson rt<<1,l,m
#define rson rt<<1|1,m+1,r
int n,m;
int pos[200];
int len[1010];
int row[1010][1010];
int col[1010][1010];
char mp[1010][1010];
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%s",mp[i]+1);
for(int i=1;i<=n;i++)
{
memset(pos,0,sizeof(pos));
for(int j=1;j<=m;j++)
{
row[i][j]=min(row[i][j-1]+1,j-pos[mp[i][j]]);
pos[mp[i][j]]=j;
}
}
for(int j=1;j<=m;j++)
{
memset(pos,0,sizeof(pos));
for(int i=1;i<=n;i++)
{
col[i][j]=min(col[i-1][j]+1,i-pos[mp[i][j]]);
pos[mp[i][j]]=i;
}
}
LL res=0;
for(int j=1;j<=m;j++)
{
memset(len,0,sizeof(len));
for(int i=1;i<=n;i++)
{
for(int k=0;k<row[i][j];k++)
{
len[k]=min(col[i][j-k],len[k]+1);
if(k) len[k]=min(len[k],len[k-1]);
res+=len[k];
}
for(int k=row[i][j];k<=52;k++)
len[k]=0;
}
}
printf("%lld\n",res);
return 0;
}