字符串T的一个子串定义为:
T(i,k)=T[i]T[i+1]...T[i+k-1],1<=i<=i+k-1<=|T|。
给出两个字符串A,B,一个整数K,我们定义S是一个由三元组构成的集合:
S={(i,j,k)|k>=K,A(i,k)=B(j,k)}。
你需要对于A,B,K,给出|S|的值。
【输入格式】
输入包含多组数据。
每组数据的第一行是一个正整数K,接下来是两行,分别是字符串A和B。输入结束标志为K=0。
【输出格式】
对每组数据输出一行一个整数,即|S|。
【样例输入】
2
aababaa
abaabaa
1
xx
xx
0
【样例输出】
22
5
【提示】
1<=|A|,|B|<=10^5
1<=K<=min{|A|,|B|}
A和B中仅含小写字母。
把两个串连接起来,求一个后缀数组。
在rank数组中从前往后枚举起点,对于每个枚举的起点,都暴力的往后扫,扫的过程中维护一个height的最小值。每到一个点的时候,如果这个点跟起点不属于一个串,就将答案加上当前的最小值,这样是O(n2)的
考虑这个还能怎么算。可以发现我们是维护height的最小值。那么我们可以按照height从大到小的顺序扫,这样每次需要用的就是当前的height。
扫的过程中用并查集维护一下每个串分别对哪些串有贡献的(也就是height数组的贡献)。
用乘法原理算一下当前的height会有多少贡献。就是用当前的height乘上这个串和上一个串分别对于两个两个不同的原串的乘积的和。
需要特判当前串是否大于等于k
代码
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define LL long long const int N=400010; LL ans; char ss[N]; int lk; int n,m,len[2],sa[N],c[N],rank[N],height[N],t1[N],t2[N],s[N],fa[N],st[N],en[N],a[N]; inline bool cmp(int *y,int p,int q,int k){ int o0=p+k>=n?-1:y[p+k]; int o1=q+k>=n?-1:y[q+k]; return o0==o1&&y[p]==y[q]; } inline void build_sa(){ int i,k,p,*x=t1,*y=t2; for(m=28,i=0;i<m;++i) c[i]=0; for(i=0;i<n;++i) ++c[x[i]=s[i]]; for(i=1;i<m;++i) c[i]+=c[i-1]; for(i=n-1;i>=0;--i) sa[--c[x[i]]]=i; for(k=1;k<=n;k<<=1){ for(p=0,i=n-k;i<n;++i) y[p++]=i; for(i=0;i<n;++i) if(sa[i]>=k) y[p++]=sa[i]-k; for(i=0;i<m;++i) c[i]=0; for(i=0;i<n;++i) ++c[x[y[i]]]; for(i=1;i<m;++i) c[i]+=c[i-1]; for(i=n-1;i>=0;--i) sa[--c[x[y[i]]]]=y[i]; swap(x,y); m=1;x[sa[0]]=0; for(i=1;i<n;++i) x[sa[i]]=cmp(y,sa[i],sa[i-1],k)?m-1:m++; if(m>=n) break; } } inline void build_height(){ int k=0,j; for(int i=0;i<n;++i) rank[sa[i]]=i; for(int i=0;i<n;++i){ if(!rank[i]) continue; k--; if(k<0) k=0; j=sa[rank[i]-1]; while(s[i+k]==s[j+k]) k++; height[rank[i]]=k; } } inline bool CMP(int x,int y){ return height[x]>height[y]; } inline int find(int x){ if(x!=fa[x]) fa[x]=find(fa[x]); return fa[x]; } int mk; inline void calc(int x){ if(!x) return ; int r1=find(x),r2=find(x-1); int res=height[x]>=mk?height[x]-mk+1:0; ans+=(LL)(st[r1]*en[r2]+st[r2]*en[r1])*(LL)res; st[r1]+=st[r2];en[r1]+=en[r2];fa[r2]=r1; } int main(){ freopen("commonsubstrings.in","r",stdin); freopen("commonsubstrings.out","w",stdout); while(scanf("%d",&mk)&&mk!=0){ memset(s,0,sizeof(s)); memset(c,0,sizeof(c)); ans=0; scanf("%s",ss); len[0]=strlen(ss); for(int i=0;i<len[0];++i) s[i]=ss[i]-'a'+1; scanf("%s",ss); len[1]=strlen(ss); int i; for(s[len[0]]=27,i=0;i<len[1];i++) s[i+len[0]+1]=ss[i]-'a'+1; n=len[0]+len[1]+1; m=40; build_sa(); build_height(); for(i=0;i<n;++i){ a[i]=fa[i]=i; st[i]=(sa[i]<len[0]); en[i]=1-st[i]; } sort(a,a+n,CMP); for(i=0;i<n;++i) calc(a[i]); printf("%lld\n",ans); } return 0; }