链接:http://poj.org/problem?id=3415
求两串中长度大于k的公共子串有多少个。
公共子串可以通过height求,中间分隔连接两串,将height[i]>=k进行分组,对于一组内的height[i],且sa[i]属于a串,需要找到j<i的串属于b则两串之间的公共子串有个数cnt=min(height[j->i]-k),采用单调栈维护一个栈顶最小的height[i],大于栈顶压入,小于更新。每次针对a/b串找前面的b/a串,跑两次。
//#include <bits/stdc++.h>
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <queue>
#include <cmath>
#include <string>
#include <map>
using namespace std;
#define INF 0x3f3f3f3f
using namespace std;
const int MAXN = 4e5+7;
int t1[MAXN],t2[MAXN],c[MAXN];
int sa[MAXN],ra[MAXN],height[MAXN];
char str1[MAXN],str2[MAXN];
int len1,len2,len;
int num[MAXN];
int cnt[MAXN];
bool cmp(int *r, int a, int b,int l)
{
return r[a]==r[b]&&r[a+l]==r[b+l];
}
void DA(int str[], int n, int m)
{
n++;
int i,j,p,*x=t1,*y=t2;
for(i=0; i<m; ++i)c[i]=0;
for(i=0; i<n; ++i)c[x[i]=str[i]]++;
for(i=1; i<m; ++i)c[i]+=c[i-1];
for(i=n-1; i>=0; --i)sa[ --c[ x[i] ] ]=i;
for(j=1; j<=n; j<<=1)
{
p=0;
for(i=n-j; i<n; ++i)y[p++]=i;
for(i=0; i<n; ++i)if(sa[i]>=j)y[p++]=sa[i]-j;
for(i=0; i<m; ++i)c[i]=0;
for(i=0; i<n; ++i)c[x[y[i]]]++;
for(i=1; i<m; ++i)c[i]+=c[i-1];
for(i=n-1; i>=0; --i)sa[--c[x[y[i]]]]=y[i];
swap(x,y);
p=1;
x[sa[0]]=0;
for(i=1; i<n; ++i)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
m=p;
}
int k=0;
n--;
for(i=0; i<=n; ++i)ra[sa[i]]=i;
for(i=0; i<n; ++i)
{
if(k)--k;
j=sa[ra[i]-1];
while(str[i+k]==str[j+k])k++;
height[ra[i]]=k;
}
}
long long que[MAXN][2];
int main()
{
int k;
while(scanf("%d",&k)!=EOF)
{
if(k==0)break;
scanf("%s",str1);
scanf("%s",str2);
len1=strlen(str1);
len2=strlen(str2);
//strcat(str1,str2);
len=len1+len2+1;
str1[len1]='Z'+1;
for(int i=1; i<=len2; ++i)
str1[len1+i]=str2[i-1];
str1[len]=0;
for(int i=0; i<len; ++i)
num[i]=str1[i];
num[len]=0;
DA(num,len,200);
int top=0;
long long ans=0,sum=0,cnt=0;
for(int i=1; i<=len; ++i)
{
if(height[i]<k){top=sum=cnt=0;}
else
{
cnt=0;
if(sa[i-1]<len1)cnt=1,sum+=height[i]-k+1;
while(top!=0 && que[top-1][0]>=height[i])
{
top--;
sum-=1LL*(que[top][0]-height[i])*que[top][1];
cnt+=que[top][1];
}
que[top][0]=height[i];
que[top++][1]=cnt;
if(sa[i]>len1)ans+=sum;
}
}
top=0;
for(int i=1; i<=len; ++i)
{
if(height[i]<k){top=sum=cnt=0;}
else
{
cnt=0;
if(sa[i-1]>len1)cnt=1,sum+=height[i]-k+1;
while(top!=0 && que[top-1][0]>=height[i])
{
top--;
sum-=1LL*(que[top][0]-height[i])*que[top][1];
cnt+=que[top][1];
}
que[top][0]=height[i];
que[top++][1]=cnt;
if(sa[i]<len1)ans+=sum;
}
}
printf("%I64d\n",ans);
}
return 0;
}