如代码注释。
#include<cstdio>
#include<cstring>
#include<map>
using namespace std;
int nx[1001000];
const int modd=998244353;
typedef unsigned long long ull;
map<ull,int> ma;
void getnx(unsigned int s1[],int n)
{
nx[1]=0;
for(int i=2,j=1;i<=n+1;)//next[i] 表示最大的x,满足s1[1 : x - 1] 是s1[1 : i - 1] 的后缀。即i失配该把i变成啥
{
nx[i]=j;
while(j&&s1[j]!=s1[i])j=nx[j];
j++,i++;
}
}
char ss[1001000];
const int NN=1e5+10;
unsigned int* s[NN];
int lenth[NN];
int cnt[1001000];
const unsigned int jinzhi=31;
int main(){
int n;scanf("%d\n",&n);
for(int i=1;i<=n;i++){
scanf("%s",ss+1);
int len=strlen(ss+1);
lenth[i]=len;
s[i]=new unsigned int[len+2];
for(int j=1;j<=len;j++){
s[i][j]=ss[j]-'a'+1;
}
ull pow=1;
ull ha=0;
for(int j=len;j>=1;j--){
ha=pow*s[i][j]+ha;
pow*=jinzhi;
ma[ha]+=1;
}
}
long long ans=0;
for(int i=1;i<=n;i++){
int len=lenth[i];
getnx(s[i],len);
for(int j=1;j<=len;j++){//这里nxi变成1到i的最长公共前后缀长度。
nx[j]=nx[j+1]-1;
}
ull ha=0;
for(int j=1;j<=len;j++){
ha=ha*jinzhi+s[i][j];
cnt[j]=ma[ha];
}
for(int j=1;j<=len;j++){
cnt[nx[j]]-=cnt[j];
}
for(int j=1;j<=len;j++){
ans+=1ll*cnt[j]*j*j%modd;
ans=ans%modd;
}
}
printf("%lld\n",ans);
return 0;
}