把两个字符串连起来,后缀数组求出height数组后用单调栈维护。记录排名i的后缀的height可以向前延伸到最远时,其间l[i]个后缀属于字符串1,r[i]个后缀属于字符串2。后缀i每弹出一个栈内元素,贡献为当前的l[i]*r[i]再乘弹出元素与前一个栈内元素的差或与height[i]的差的最小值。注意后缀i的l和r数组初始值应由后缀i-1属于哪个字符串决定。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=400010;
int n1,n,m,rnk[N],tp[N],sa[N],hgt[N],tak[N],top,q[N],pos[N],l[N],r[N];
char str1[N];
long long ans;
inline void sort1(){
for(int i=1;i<=m;i++)tak[i]=0;
for(int i=1;i<=n;i++)tak[rnk[i]]++;
for(int i=2;i<=m;i++)tak[i]+=tak[i-1];
for(int i=n;i>=1;i--)sa[tak[rnk[tp[i]]]--]=tp[i];
}
int main(){
scanf("%s",str1+1);
n=n1=strlen(str1+1);
for(int i=1;i<=n1;i++){rnk[i]=str1[i]-'a'+1;tp[i]=i;}
str1[n+1]='&';rnk[n+1]=27;tp[n+1]=n+1;
scanf("%s",str1+n1+2);n=strlen(str1+1);
for(int i=n1+2;i<=n;i++){rnk[i]=str1[i]-'a'+1;tp[i]=i;}
m=27;sort1();
for(int p=0,w=1;p<n;w<<=1,m=p){
p=0;
for(int i=1;i<=w;i++)tp[++p]=n-w+i;
for(int i=1;i<=n;i++)if(sa[i]>w)tp[++p]=sa[i]-w;
sort1();swap(tp,rnk);
rnk[sa[1]]=p=1;
for(int i=2;i<=n;i++){
if(tp[sa[i]]!=tp[sa[i-1]]||tp[sa[i]+w]!=tp[sa[i-1]+w])p++;
rnk[sa[i]]=p;
}
}
for(int x,j=0,i=1;i<=n;i++){
if(j)j--;
x=sa[rnk[i]-1];
while(str1[i+j]==str1[x+j])j++;
hgt[rnk[i]]=j;
}
ans=0;
memset(l,0,sizeof l);memset(r,0,sizeof r);
top=1;q[0]=q[top]=0;pos[1]=1;
for(int i=2;i<=n;i++){
if(sa[i-1]<=n1)l[i]=1;
else r[i]=1;
while(top>0&&hgt[i]<=q[top]){
l[i]+=l[pos[top]];r[i]+=r[pos[top]];
ans+=1ll*l[i]*r[i]*(q[top]-max(q[top-1],hgt[i]));
top--;
}
q[++top]=hgt[i];pos[top]=i;
}
printf("%lld",ans);
return 0;
}