BZOJ3238[Ahoi2013]差异
题目描述
n<=500000
n
<=
500000
,都是小写字母
输入
一行,一个字符串S
输出
一行,一个整数,表示所求值
Solution
公式的前两项可以化简
∑ni=1∑nj=i+1len(Ti)+len(Tj)=(n−1)∗n∗(n+1)2 ∑ i = 1 n ∑ j = i + 1 n l e n ( T i ) + l e n ( T j ) = ( n − 1 ) ∗ n ∗ ( n + 1 ) 2 ;
然后只要求出所有后缀的
lcp
l
c
p
用
(n−1)∗n∗(n+1)2
(
n
−
1
)
∗
n
∗
(
n
+
1
)
2
减一下就好
求一遍后缀数组
O(nlogn)
O
(
n
l
o
g
n
)
然后
n2
n
2
枚举后缀? naive!
后缀数组求出来以后就知道了height
问题就转化成了求所有区间最小值的和
因为
lcp(Ti,Tj)=min(height(k))j<=k<=i,rank[j]<=rank[j] l c p ( T i , T j ) = m i n ( h e i g h t ( k ) ) j <= k <= i , r a n k [ j ] <= r a n k [ j ]
所以对于一些lcp相同的区间我们不必多次求
我们可以用单调栈分别从左从右延伸
区间数量怎么表示?
其实左区间长度乘以右区间长度就好
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef int mian;
#define int long long
#define maxn 1000010
int t1[maxn],t2[maxn];
char str[maxn];
int s[maxn],height[maxn],rak[maxn],sa[maxn];
int c[maxn],l[maxn],r[maxn],stk[maxn];
void SA(int n){
int *x=t1,*y=t2;
int m=100000;
for(int i=1;i<=m;++i) c[i]=0;
for(int i=1;i<=n;++i) c[x[i]=s[i]]++;
for(int i=1;i<=m;++i) c[i]+=c[i-1];
for(int i=n;i;--i) sa[c[x[i]]--]=i;
for(int k=1;k<=n;k<<=1){
int p=0;
for(int i=n-k+1;i<=n;++i) y[++p]=i;
for(int i=1;i<=n;++i) if(sa[i]>k) y[++p]=sa[i]-k;
for(int i=1;i<=m;++i) c[i]=0;
for(int i=1;i<=n;++i) c[x[y[i]]]++;
for(int i=1;i<=m;++i) c[i]+=c[i-1];
for(int i=n;i;--i) sa[c[x[y[i]]]--]=y[i];
swap(x,y),p=1,x[sa[1]]=1;
for(int i=2;i<=n;++i)
x[sa[i]]=y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k]?p:++p;
if(p>=n) break;
m=p;
}
for(int i=1;i<=n;++i) rak[sa[i]]=i;
int k=0;
for(int i=1;i<=n;++i){
if(k) --k;
int j=sa[rak[i]-1];
while(s[i+k]==s[j+k])++k;
height[rak[i]]=k;
}
return ;
}
mian main(){
scanf("%s",str+1);
int n=strlen(str+1);
for(int i=1;i<=n;++i) s[i]=str[i]-'a'+1;
SA(n);
int top=0;
int ans=n*(n-1)*(n+1)>>1;
for(int i=1;i<=n;++i){
while(top&&height[stk[top]]>height[i]){
r[stk[top--]]=i-1;
}
stk[++top]=i;
}
while(top) r[stk[top--]]=n;
for(int i=n;i;--i){
while(top&&height[i]<=height[stk[top]]) l[stk[top--]]=i+1;
stk[++top]=i;
}
while(top) r[stk[top--]]=1;
for(int i=1;i<=n;++i){
ans-=height[i]*(r[i]-i+1)*((i-l[i]+1)<<1);
}
printf("%lld\n",ans);
}