题解:求解每个位置向左向右AA串的个数f[x],g[x];
枚举A的长度,每A个位置设一个关键点
每一个A一定仅且跨越一个关键点
然后求出相邻关键点向前向后的最长公共前缀的长度,这会对一段区间的f,g产生影响;
用差分+前缀和统计答案
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
const int maxn=40009;
int T;
int n;
struct Suffix_Array{
char s[maxn];
int c[maxn],t1[maxn],t2[maxn],sa[maxn];
void Buildsa(){
int m=200;
int *x=t1,*y=t2;
for(int i=1;i<=m;++i)c[i]=0;
for(int i=1;i<=n;++i)c[x[i]=s[i]]++;
for(int i=2;i<=m;++i)c[i]+=c[i-1];
for(int i=n;i>=1;--i)sa[c[x[i]]--]=i;
for(int k=1;k<n;k<<=1){
int p=0;
for(int i=n-k+1;i<=n;++i)y[++p]=i;
for(int i=1;i<=n;++i){
if(sa[i]>k)y[++p]=sa[i]-k;
}
for(int i=1;i<=m;++i)c[i]=0;
for(int i=1;i<=n;++i)c[x[y[i]]]++;
for(int i=2;i<=m;++i)c[i]+=c[i-1];
for(int i=n;i>=1;--i)sa[c[x[y[i]]]--]=y[i];
swap(x,y);
x[sa[1]]=p=1;
for(int i=2;i<=n;++i){
if((y[sa[i]]==y[sa[i-1]])&&(y[sa[i]+k]==y[sa[i-1]+k]))x[sa[i]]=p;
else x[sa[i]]=++p;
}
if(p>=n)break;
m=p;
}
}
int rank[maxn];
int height[maxn];
void Getheight(){
for(int i=1;i<=n;++i)rank[sa[i]]=i;
int k=0;
for(int i=1;i<=n;++i){
if(k)k--;
if(rank[i]==1)continue;
int j=sa[rank[i]-1];
while(s[i+k]==s[j+k])++k;
height[rank[i]]=k;
}
}
int f[maxn][20];
void STinit(){
for(int i=1;i<=n;++i)f[i][0]=height[i];
for(int j=1;j<=19;++j){
for(int i=1;i+(1<<j)-1<=n;++i){
f[i][j]=min(f[i][j-1],f[i+(1<<(j-1))][j-1]);
}
}
}
int Querymin(int l,int r){
int k=log2(r-l+1.5);
return min(f[l][k],f[r-(1<<k)+1][k]);
}
int Lcp(int l,int r){
l=rank[l];
r=rank[r];
if(l>r)swap(l,r);
return Querymin(l+1,r);
}
void SAinit(){
memset(s,0,sizeof(s));
memset(c,0,sizeof(c));
memset(t1,0,sizeof(t1));
memset(t2,0,sizeof(t2));
memset(sa,0,sizeof(sa));
memset(rank,0,sizeof(rank));
memset(height,0,sizeof(height));
memset(f,0,sizeof(f));
}
}a[2];
int tb[2][maxn];
int main(){
scanf("%d",&T);
while(T--){
a[0].SAinit();
a[1].SAinit();
memset(tb,0,sizeof(tb));
scanf("%s",a[0].s+1);
n=strlen(a[0].s+1);
for(int i=1;i<=n;++i)a[1].s[i]=a[0].s[n-i+1];
a[0].Buildsa();
a[1].Buildsa();
a[0].Getheight();
a[1].Getheight();
a[0].STinit();
a[1].STinit();
// for(int i=1;i<=n;++i)printf("%d ",a[0].sa[i]);
// cout<<endl;
for(int k=1;k<=n;++k){
for(int i=k;i+k<=n;i+=k){
int lenl=a[1].Lcp(n-i+1,n-(i+k)+1);
int lenr=a[0].Lcp(i,i+k);
// printf("len=%d lenl=%d lenr=%d\n",k,lenl,lenr);
lenl=min(lenl,k);
lenr=min(lenr,k);
if(lenl+lenr-1<k)continue;
tb[0][i+2*k-lenl]++;tb[0][i+k+lenr]--;
tb[1][i-lenl+1]++;tb[1][i+lenr-k+1]--;
}
}
// for(int i=1;i<=n;++i)printf("%d ",tb[0][i]);
// cout<<endl;
// for(int i=1;i<=n;++i)printf("%d ",tb[1][i]);
// cout<<endl;
for(int i=2;i<=n;++i){
tb[0][i]+=tb[0][i-1];
tb[1][i]+=tb[1][i-1];
}
long long ans=0;
for(int i=1;i<n;++i){
ans=ans+1LL*tb[0][i]*tb[1][i+1];
}
printf("%lld\n",ans);
}
return 0;
}