Description
Input
一行,一个字符串S
Output
一行,一个整数,表示所求值
Sample Input
Sample Output
54
HINT
2<=N<=500000,S由小写英文字母组成
题面似乎非常……lcp的第一感觉就是后缀数组后缀自动机什么的了= =
然后发现其实确实求一遍sa和height就可以做了。
网上题解基本都是单调栈……
我的方法似乎不太一样。
对于后缀x和后缀y(假设x和y是它们的排名),它们的lcp就是x+1~y间的min{height}
那么对于l~r,求出l+1~r的min height为m位置,
显然,l~(m-1)为一组,m~r为一组,因为这两组里的后缀lcp两两都是height[m];
如果说l~m为一组,你会发现m+1和m的lcp不一定是height[m].
然后直接统计m两边的后缀的贡献即可。
我具体写一下:
考虑l~(m-1)里有个a,m~r里的后缀长度分别为b1,b2,……b(r-m+1),
那么a的贡献是
len(a)+len(b1)-2*height[m]+
len(a)+len(b2)-2*height[m]+
……
len(a)+len(b(r-m+1))-2*height[m].
我们把它们加起来,而一段里的后缀长度的和我们可以在外面预处理出来的,
假设这个l~r的前缀和为S[l..r],则a的贡献就是:
(r-m+1)*len(a)+S[m..r]-2*(r-m+1)*height[m];
接下来我们对于l~(m-1)里的每个这样的a的单独贡献进行累加,
就是:
(r-m+1)*len(a1)+S[m..r]-2*(r-m+1)*height[m]+
(r-m+1)*len(a2)+S[m..r]-2*(r-m+1)*height[m]+
(r-m+1)*len(a3)+S[m..r]-2*(r-m+1)*height[m]+
……
(r-m+1)*len(a(m-l))+S[m..r]-2*(r-m+1)*height[m]
=
(r-m+1)*S[l..m-1]+(m-l)*S[m..r]-2*(m-l)*(r-m+1)*height[m]
然后对(l,m-1)以及(m,r)进行递归计算即可。
由于要预处理st表,所以时间复杂度是O(nlogn+n)
注意long long哪里都不要忘。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int
N=500005;
int n;
int cnta[N],cntb[N],a[N],b[N];
int tsa[N],rank[N<<1],sa[N];
int st[N][20];
ll sum[N],H[N],ans;
char s[N];
void get_SA(){
for (int i=0;i<26;i++) cnta[i]=0;
for (int i=1;i<=n;i++) cnta[s[i]-97]++;
for (int i=1;i<26;i++) cnta[i]+=cnta[i-1];
for (int i=n;i;i--) sa[cnta[s[i]-97]--]=i;
rank[sa[1]]=1;
for (int i=2;i<=n;i++)
rank[sa[i]]=rank[sa[i-1]]+(s[sa[i]]!=s[sa[i-1]]);
for (int j=1;rank[sa[n]]!=n;j<<=1){
for (int i=1;i<=n;i++) a[i]=rank[i],b[i]=rank[i+j];
for (int i=0;i<=n;i++) cnta[i]=cntb[i]=0;
for (int i=1;i<=n;i++) cnta[a[i]]++,cntb[b[i]]++;
for (int i=1;i<=n;i++) cnta[i]+=cnta[i-1],cntb[i]+=cntb[i-1];
for (int i=n;i;i--) tsa[cntb[b[i]]--]=i;
for (int i=n;i;i--) sa[cnta[a[tsa[i]]]--]=tsa[i];
rank[sa[1]]=1;
for (int i=2;i<=n;i++)
rank[sa[i]]=rank[sa[i-1]]+(a[sa[i]]!=a[sa[i-1]] || b[sa[i]]!=b[sa[i-1]]);
}
}
void get_H(){
ll len=0;
for (int i=1;i<=n;i++){
if (len) len--;
while (s[i+len]==s[sa[rank[i]-1]+len]) len++;
H[rank[i]]=len;
}
}
void pre_ST(){
for (int i=1;i<=n;i++) st[i][0]=i;
for (int j=1;j<=19;j++)
for (int i=1;i<=n;i++)
if (i+(1<<j)-1>n) break; else
if (H[st[i][j-1]]<H[st[i+(1<<(j-1))][j-1]]) st[i][j]=st[i][j-1];
else st[i][j]=st[i+(1<<(j-1))][j-1];
}
void pre_SUM(){
sum[0]=0LL;
for (int i=1;i<=n;i++)
sum[i]=sum[i-1]+n-sa[i]+1;
}
int query_ST(int l,int r){
int k=(double)log(r-l+1)/(double)log(2);
int t1=st[l][k],t2=st[r-(1<<k)+1][k];
return H[t1]<H[t2]?t1:t2;
}
void solve(ll l,ll r){
if (l>=r) return;
int m=query_ST(l+1,r);
ll s1=sum[m-1]-sum[l-1],s2=sum[r]-sum[m-1];
ans+=(r-m+1)*s1+(m-l)*s2-2LL*(m-l)*(r-m+1)*H[m];
solve(l,m-1);
solve(m,r);
}
int main(){
scanf("%s",s+1);
n=strlen(s+1);
get_SA(),get_H();
pre_ST(),pre_SUM();
ans=0LL;
solve(1LL,n);
printf("%lld\n",ans);
return 0;
}