题意:给三个串s1,s2,s3,对于每个长度L(1<=L<=min(length(s1,s2,s3)))求有多少个三元组<p1,p2,p3>使得s1[p1...p1+L-1]==s2[p2..p2+L-1]==s3[p3..p3+L-1],求出所有L对应的答案对1e9+7取模。
题解:
先把串连在一起,记录每个位置所属的串,跑一下后缀数组求出height数组。
如果从位置pos开始的最大公共子串长L,那么所有从pos开始的长度小于L的子串都是符合条件的。
这里可以按height从大到小的顺序,一边计算一边用并查集合并区间,对于要合并的两个位置u1,u2,首先他们一定是相邻。用sum[type][pos]来维护每个位置pos所包含的s1,s2,s3的数量,那么u1与u2合并以后(以u1为fa),对答案的贡献就是
sum[0][u1]*sum[1][u1]*sum[2][u1],因为涉及重复统计,那么在相加之前,要先减掉u1和u2中的贡献。
#include<bits/stdc++.h>
#define rint register int
#define inv inline void
#define ini inline int
#define maxn 3000050
using namespace std;
typedef long long ll;
const ll mod=1000000007;
char s[maxn],t[maxn];
int y[maxn],x[maxn],c[maxn],sa[maxn],rk[maxn],height[maxn],wt[30];
int n,m;
inv putout(int x) {
if(!x) {
putchar(48);
return;
}
rint l=0;
while(x) wt[++l]=x%10,x/=10;
while(l) putchar(wt[l--]+48);
}
inv get_SA() {
for (rint i=1; i<=n; ++i) ++c[x[i]=s[i]];
for (rint i=2; i<=m; ++i) c[i]+=c[i-1];
for (rint i=n; i>=1; --i) sa[c[x[i]]--]=i;
for (rint k=1; k<=n; k<<=1) {
rint num=0;
for (rint i=n-k+1; i<=n; ++i) y[++num]=i;
for (rint i=1; i<=n; ++i) if (sa[i]>k) y[++num]=sa[i]-k;
for (rint i=1; i<=m; ++i) c[i]=0;
for (rint i=1; i<=n; ++i) ++c[x[i]];
for (rint i=2; i<=m; ++i) c[i]+=c[i-1];
for (rint i=n; i>=1; --i) sa[c[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);
x[sa[1]]=1;
num=1;
for (rint i=2; i<=n; ++i)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k]) ? num : ++num;
if (num==n) break;
m=num;
}
}
inv get_height() {
rint k=0;
for (rint i=1; i<=n; ++i) rk[sa[i]]=i;
for (rint i=1; i<=n; ++i) {
if (rk[i]==1) continue;//第一名height为0
if (k) --k;//h[i]>=h[i-1]-1;
rint j=sa[rk[i]-1];
while (j+k<=n && i+k<=n && s[i+k]==s[j+k]) ++k;
height[rk[i]]=k;//h[i]=height[rk[i]];
}
}
int p[maxn],belong[maxn],fa[maxn];
ll sum[4][maxn],ans[maxn];
char ss[maxn];
bool cmp(int x,int y){
return height[x]>height[y];
}
inline int findfa(int u){
if(fa[u]==u) return u;
else return fa[u]=findfa(fa[u]);
}
int main() {
scanf("%s",ss+1);
int len=strlen(ss+1);
int aa=len;
n=0;
for(int i=1;i<=len;i++) belong[++n]=0,s[n]=ss[i];
s[++n]=1;
belong[n]=-1;
scanf("%s",ss+1);
len=strlen(ss+1);aa=min(aa,len);
for(int i=1;i<=len;i++) belong[++n]=1,s[n]=ss[i];
s[++n]=2;
belong[n]=-1;
scanf("%s",ss+1);
len=strlen(ss+1);aa=min(aa,len);
for(int i=1;i<=len;i++) belong[++n]=2,s[n]=ss[i];
s[++n]=3;
belong[n]=-1;
m=200;
get_SA();
get_height();
for(int i=1;i<=n;i++){
//printf("@%d\n",height[i]);
p[i]=i;
fa[i]=i;
sum[0][i]=sum[1][i]=sum[2][i]=0;
if(belong[i]==0) sum[0][i]=1;
if(belong[i]==1) sum[1][i]=1;
if(belong[i]==2) sum[2][i]=1;
}
sort(p+1,p+1+n,cmp);
int j=1;
ll tmp=0;
for(int i=aa;i>=1;i--){
while(j<=n&&height[p[j]]>=i){
//printf("!%d %d\n",sa[p[j]],p[j]);
int l=findfa(sa[p[j]-1]);
int r=findfa(sa[p[j]]);
tmp=(tmp+mod-(sum[0][l]*sum[1][l]%mod*sum[2][l]%mod))%mod;
tmp=(tmp+mod-(sum[0][r]*sum[1][r]%mod*sum[2][r]%mod))%mod;
sum[0][l]+=sum[0][r];
sum[1][l]+=sum[1][r];
sum[2][l]+=sum[2][r];
tmp=(tmp+(sum[0][l]*sum[1][l]%mod*sum[2][l]%mod))%mod;
fa[r]=l;
j++;
}
tmp%=mod;
ans[i]=tmp;
}
for(int i=1;i<=aa;i++) printf("%lld ",ans[i]);
}