题意:求两个字符串长度不小于 K 的公共子串个数
分析:将两个字符串 A 和 B 连接起来,中间用其它字符隔开,求一遍后缀数组,根据 sa 数组的性质,设sa[l...r]为连续一段
公共子串长度>=K的后缀,那么对于某个 i (l<=i<=r) 若 sa[i] 属于 A ,则它的贡献可以是:
对于每个 A ,我们都这样统计一遍,则我们计算了所有的按照sa数组顺序,B 在前 A 在后,两者子串大于等于 K 的情况。
现在还缺少 B 在后 A在前的情形,同理,对于某个 i (l<=i<=r) 若 sa[i] 属于 B ,则它的贡献可以是:
这样子所有的方案就出来了,但是对于每次贡献的计算,暴力往前找累加答案肯定不行,需要优化,我们知道sa[i]和sa[j](假设i<j) 的最长公共子串 lcp(i,j)=min(ht[i+1...j]),那么 j 不变的情况下 lcp(i+1...j-1,j) 肯定是非递减的,那么就可以用单调栈维护这么一个非递减序列,对于重复元素就压缩一下记录数量即可 (可能有点抽象,可以看代码理解Q_v_Q)
代码:
#include<stack>
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
#define fi first
#define sd second
typedef pair<int,int> P;
typedef long long ll;
const int N = 2E5+10;
char a[N],b[N];
int n,s[N],sa[N],rk[N],oldrk[N<<1];
int cnt[N],ht[N],px[N],id[N];
bool cmp(int x,int y,int w){
return oldrk[x]==oldrk[y] && oldrk[x+w]==oldrk[y+w];
}
void da(int s[],int n,int m){
int i,p=0,w,k;
for(i=1;i<=n;i++) ++cnt[rk[i] = s[i]];
for(i=1;i<=m;i++) cnt[i] += cnt[i-1];
for(i=n;i>=1;i--) sa[cnt[rk[i]]--] = i;
for(w=1;w<n;w<<=1,m=p){
for(p=0,i=n;i>n-w;i--) id[++p]=i;
for(i=1;i<=n;i++)
if(sa[i]>w) id[++p]=sa[i]-w;
memset(cnt,0,sizeof(cnt));
for(i=1;i<=n;i++) ++cnt[px[i] = rk[id[i]]];
for(i=1;i<=m;i++) cnt[i] += cnt[i-1];
for(i=n;i>=1;i--) sa[cnt[px[i]]--] = id[i];
memcpy(oldrk,rk,sizeof(rk));
for(p=0,i=1;i<=n;i++)
rk[sa[i]]=cmp(sa[i],sa[i-1],w)?p:++p;
}
for(i=1,k=0;i<=n;i++){
if(k) --k;
while(s[i+k]==s[sa[rk[i]-1]+k]) ++k;
ht[rk[i]]=k;
}
}
P H[N]; //单调栈,sd 记录非递减序列值,fi 记录数量;
int K,Y;
void work(){
ll ANS=0;
ll tot=0,len=0;
for(int i=2;i<=n;i++){
if(ht[i]<K) tot=len=0;
else{
int cnt=0;
if(sa[i-1]<Y) tot+=ht[i]-K+1,cnt++;
while(len>0&&H[len].sd>=ht[i]){ //相当于统计相同值的数量;
P top=H[len]; len--;
tot-=(top.sd-ht[i])*top.fi;
cnt+=top.fi;
}
H[++len]=P(cnt,ht[i]);
if(sa[i]>Y) ANS+=tot;
}
}
tot=0,len=0;
for(int i=2;i<=n;i++){
if(ht[i]<K) tot=len=0;
else{
int cnt=0;
if(sa[i-1]>Y) tot+=ht[i]-K+1,cnt++;
while(len>0&&H[len].sd>=ht[i]){
P top=H[len]; len--;
tot-=(top.sd-ht[i])*top.fi;
cnt+=top.fi;
}
H[++len]=P(cnt,ht[i]);
if(sa[i]<Y) ANS+=tot;
}
}
printf("%lld\n",ANS);
}
int main()
{
while(scanf("%d",&K)&&K){
memset(cnt,0,sizeof(cnt));
scanf("%s%s",a+1,b+1);
n=strlen(a+1);
for(int i=1;i<=n;i++) s[i]=a[i];
Y=n+1;
n=strlen(b+1);
for(int i=1;i<=n;i++) s[i+Y]=b[i];
s[Y]='#';
n+=Y;
da(s,n,128);
work();
}
}