传送门
思路:
偶然翻到的一个题
苦思冥想算法之时……
旁边不(jing)会(tong)后缀数组的聪爷爷:这不是后缀数组吗?
赶紧来练一下遗忘的后缀数组(然后手打板子又错了,只能回到博客上重新翻一波以前写的)
设
lena
,
lenb
为字符串a,b的长度
ai
,
bi
分别表示字符串a,b的后缀
[i..n]
那么答案就是
∑lenai=1∑lenbj=1lcp(ai,bj)
这个式子朴素做是
O(n3)
的
把两个串连一起然后再求后缀数组是可以做到
O(n2)
,因为lcp是可以
O(1)
求的
但仍然不能使人满意
那怎么办?
之前写过类似的题,好像用的是单调栈
这次没这么做……
考虑把后缀按照rank排序后,height值是有大有小的(这不是废话吗)
比如我们想求得rank为
[l,r]
的后缀中对答案的贡献
我们可以求出[l+1,r]中最小的height所在的位置mid
也就是说
heightmid<=heighti,i=l+1..r
可以统计答案就是[l,mid-1]的a后缀个数×[mid,r]的b后缀个数+[l,mid-1]的b后缀个数×[mid,r]的a后缀个数,统计后缀个数可以用前缀和处理一下
然后可以再递归子问题[l,mid-1],[mid,r]了
相当于从小到大枚举height
复杂度
O(n)
所以总复杂度就是后缀数组的建立与预处理ST表,即
O(nlog2n)
代码:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#define LL long long
#define M 200005
using namespace std;
int la,lb,lc;
char a[M],b[M];
int c[M<<1],cnt[M<<1],rank[M<<1],sa[M<<1],id[M<<1],tmp[M<<1],height[M<<1],ST[19][M<<1],sum1[M<<1],sum2[M<<1];
void SA(int len,int up)
{
int p=0,d=1,*rk=rank,*t=tmp;
for (int i=0;i<len;++i) ++cnt[rk[i]=c[i]];
for (int i=1;i<up;++i) cnt[i]+=cnt[i-1];
for (int i=len-1;i>=0;--i) sa[--cnt[rk[i]]]=i;
for (;;)
{
for (int i=len-d;i<len;++i) id[p++]=i;
for (int i=0;i<len;++i)
if (sa[i]-d>=0) id[p++]=sa[i]-d;
for (int i=0;i<up;++i) cnt[i]=0;
for (int i=0;i<len;++i) ++cnt[t[i]=rk[id[i]]];
for (int i=1;i<up;++i) cnt[i]+=cnt[i-1];
for (int i=len-1;i>=0;--i) sa[--cnt[t[i]]]=id[i];
swap(t,rk);
p=1;
rk[sa[0]]=0;
for (int i=0;i<len-1;++i)
if (sa[i]+d<len&&sa[i+1]+d<len&&t[sa[i]]==t[sa[i+1]]&&t[sa[i]+d]==t[sa[i+1]+d])
rk[sa[i+1]]=p-1;
else
rk[sa[i+1]]=p++;
if (p==len) return;
d<<=1;up=p;p=0;
}
}
void Height()
{
for (int i=1;i<=lc;i++) rank[sa[i]]=i;
int x,k=0;
for (int i=0;i<lc;++i)
{
k=max(k-1,0);
x=sa[rank[i]-1];
while (c[x+k]==c[i+k]) ++k;
height[rank[i]]=k;
}
}
LL solve(int l,int r)
{
if (l>=r) return 0;
int p=log2(r-l),mid;
mid=height[ST[p][l+1]]>height[ST[p][r-(1<<p)+1]]?ST[p][r-(1<<p)+1]:ST[p][l+1];
return solve(l,mid-1)+solve(mid,r)+((LL)(sum1[mid-1]-sum1[l-1])*(sum2[r]-sum2[mid-1])+(LL)(sum2[mid-1]-sum2[l-1])*(sum1[r]-sum1[mid-1]))*height[mid];
}
main()
{
scanf("%s",a);scanf("%s",b);
la=strlen(a);lb=strlen(b);
for (int i=0;i<la;++i)
c[i]=a[i]-'a'+1;
c[la]='{'-'a'+1;
for (int i=la+1;i<=lb+la;++i)
c[i]=b[i-la-1]-'a'+1;
SA(la+lb+2,29);
lc=la+lb+1;
Height();
for (int i=1;i<=la+lb;++i)
sum1[i]=sum1[i-1]+(sa[i]<la),
sum2[i]=sum2[i-1]+(sa[i]>la);
for (int i=1;i<=lc;++i) ST[0][i]=i;
for (int i=1;1<<i<=lc;++i)
for (int j=1;(1<<i)+j-1<=lc;++j)
ST[i][j]=(height[ST[i-1][j]]>height[ST[i-1][j+(1<<i-1)]]?ST[i-1][j+(1<<i-1)]:ST[i-1][j]);
printf("%lld",solve(1,lc-1));
}