题目描述
SA题
朴素大概要个很高的复杂度。
想一个高端一点的暴力,可以只枚举两个后缀,对于这两个后缀任意前缀之间lcp可以列出数学式子,这个式子与这两个后缀的长度以及它们的lcp长度有关。
接下来我们知道lcp等于一段区间height的最小值。
因此写个sa,然后根据height建立笛卡尔树。
接着递归维护需要维护的信息,每次以一个点为lcp值统计答案。
式子因为忘了怎么推就不推啦!
#include<cstdio>
#include<algorithm>
#include<cstring>
//#include<ctime>
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
using namespace std;
typedef long long ll;
const int maxn=500000+10,mo=998244353;
char s[maxn];
int rank[maxn*2],sa[maxn],height[maxn],a[maxn],b[maxn],c[maxn],d[maxn],len[maxn],le[maxn];
int fa[maxn],left[maxn],right[maxn],sta[maxn],size[maxn],sum[maxn],num[maxn];
int i,j,k,l,r,mid,t,n,m,ans,top,root;
void getsa(){
fo(i,1,n) b[s[i]-'a']++;
fo(i,1,26) b[i]+=b[i-1];
fd(i,n,1) c[b[s[i]-'a']--]=i;
t=0;
fo(i,1,n){
if (s[c[i]]!=s[c[i-1]]) t++;
rank[c[i]]=t;
}
j=1;
while (j<=n){
fo(i,0,n) b[i]=0;
fo(i,1,n) b[rank[i+j]]++;
fo(i,1,n) b[i]+=b[i-1];
fd(i,n,1) c[b[rank[i+j]]--]=i;
fo(i,0,n) b[i]=0;
fo(i,1,n) b[rank[c[i]]]++;
fo(i,1,n) b[i]+=b[i-1];
fd(i,n,1) d[b[rank[c[i]]]--]=c[i];
t=0;
fo(i,1,n){
if (rank[d[i]]!=rank[d[i-1]]||rank[d[i]+j]!=rank[d[i-1]+j]) t++;
c[d[i]]=t;
}
fo(i,1,n) rank[i]=c[i];
if (t==n) break;
j*=2;
}
fo(i,1,n) sa[rank[i]]=i;
}
void getheight(){
k=0;
fo(i,1,n){
if (k) k--;
j=sa[rank[i]-1];
while (i+k<=n&&j+k<=n&&s[i+k]==s[j+k]) k++;
height[rank[i]]=k;
}
}
void dfs(int x){
if (!x) return;
dfs(left[x]);
dfs(right[x]);
size[x]=size[left[x]]+size[right[x]]+1;
sum[x]=(sum[left[x]]+sum[right[x]])%mo;
sum[x]=(sum[x]+len[x])%mo;
num[x]=(num[left[x]]+num[right[x]])%mo;
num[x]=(num[x]+le[x])%mo;
int t=a[x];
j=((ll)t*(t+1)/2)%mo;
k=(ll)(num[left[x]]+le[x])*j%mo*(size[right[x]]+1)%mo;
(ans+=k)%=mo;
k=(ll)(sum[right[x]]+len[x])*j%mo*(size[left[x]]+1)%mo;
(ans+=k)%=mo;
k=(ll)(size[left[x]]+1)*(size[right[x]]+1)%mo*j%mo;
(ans+=k)%=mo;
j=((ll)t*(t+1)*(2*t+1)/3)%mo;
k=(ll)(size[left[x]]+1)*(size[right[x]]+1)%mo*j%mo;
(ans-=k)%=mo;
j=(ll)t*t%mo*t%mo;
k=(ll)(size[left[x]]+1)*(size[right[x]]+1)%mo*j%mo;
(ans+=k)%=mo;
k=(ll)(num[left[x]]+le[x])*(sum[right[x]]+len[x])%mo*t%mo;
(ans+=k)%=mo;
k=(ll)(num[left[x]]+le[x])*(size[right[x]]+1)%mo*t%mo*t%mo;
(ans-=k)%=mo;
k=(ll)(sum[right[x]]+len[x])*(size[left[x]]+1)%mo*t%mo*t%mo;
(ans-=k)%=mo;
}
int main(){
freopen("substring.in","r",stdin);freopen("substring.out","w",stdout);
scanf("%s",s+1);
n=strlen(s+1);
getsa();
getheight();
fo(i,2,n) a[i-1]=height[i],len[i-1]=n-sa[i]+1,le[i-1]=n-sa[i-1]+1;
fo(i,1,n-1){
while (top&&a[i]<a[sta[top]]){
right[fa[sta[top]]]=0;
fa[sta[top]]=i;
right[sta[top]]=left[i];
fa[left[i]]=sta[top];
left[i]=sta[top];
top--;
}
if (top){
fa[i]=sta[top];
right[sta[top]]=i;
}
sta[++top]=i;
}
fo(i,1,n)
if (!fa[i]){
root=i;
break;
}
dfs(root);
ans=(ll)ans*2%mo;
fo(i,1,n){
l=t=n-i+1;
k=((ll)t*(t+1)/2)%mo;
(ans+=(ll)(2*l+1)%mo*k%mo)%=mo;
k=((ll)t*(t+1)*(2*t+1)/3)%mo;
k=-k;
(ans+=k)%=mo;
}
(ans+=mo)%=mo;
printf("%d\n",ans);
//printf("%d\n",clock());
fclose(stdin);fclose(stdout);
}