题意 :
给定两个字符串 A 和 B,求长度不小于 k 的公共子串的个数(可以相同)。
样例 1:A=“xx”,B=“xx”,k=1,长度不小于 k 的公共子串的个数是 5。
样例 2:
A = “aababaa”,B = “abaabaa”,k=2,长度不小于 k 的公共子串的个数是22。
思路 : 基本思路是计算 A 的所有后缀和 B 的所有后缀之间的最长公共前缀的长度,把最长公共前缀长度不小于 k 的部分全部加起来。比如样例2中 A的后缀4和B的后缀1的LCP为baa = 3 , 那么这里就有2个公共子串。因为起点是不一样的,所以最后结果就是直接把所有LCP( i , j ) >= k 的部分求和就可以了。当然,如果直接搞的话,是要O( n*n ) 的,铁定要超时。比较好的做法是,将连接起来的字符串根据 K ,讲height 分组 , 那么每组内 , 所有的LCP都是大于等于k , 那么我们要做的快速的统计出每组中长度不小于k的公共子串的个数。
那么对于每一组,用一个单调栈去维护,对于每个sa[i],我们把对应的height[i+1]的值加入到单调栈中,保证栈中的height值是递增的。每次出现一个B字符串中的后缀,就统计单调栈中有多少个A字符串中的子串长度大于等于k。那么如何统计呢?我维护了一个cnt域和sum域( 这个方法可能比较麻烦 ... 不过看了其他人的报告 ... 没看懂怎么处理的,就自己想了个笨方法 ) ,cnt域表示一个较小的height加入到单调栈中的时候删掉的元素的个数( 因为较小的height加入,当时比这个height大的就最多只能由当前height这个大小了,所以要把这些删掉的元素当成大小为height算),以及sum域表示从栈底到该元素A的子串中有多少个长度大于等于k的公共子串。
cnt就直接在元素加入栈的时候维护,sum 的维护根据前一个元素的sum来维护,即 sum = s[top-1].sum + cnt * ( height[i+1] - k + 1 )
为了方便计算,我维护的height值,其实是height-k+1
当遇到一个B字符串中的元素时,ans 加上 sum[top-1].sum 即可
当然,这个过程还要再做一遍,第二遍栈里面统计的是字符串B中的信息
#include <stdio.h>
#include <string.h>
#include <string>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
#define maxn 200205
int wa[maxn],wb[maxn],wv[maxn],wt[maxn];
typedef long long LL ;
int cmp(int *r,int a,int b,int l)
{return r[a]==r[b]&&r[a+l]==r[b+l];}
void da(int *r,int *sa,int n,int m){
int i,j,p,*x=wa,*y=wb,*t;
for(i=0;i<m;i++) wt[i]=0;
for(i=0;i<n;i++) wt[x[i]=r[i]]++;
for(i=1;i<m;i++) wt[i]+=wt[i-1];
for(i=n-1;i>=0;i--) sa[--wt[x[i]]]=i;
for(j=1,p=1;p<n;j*=2,m=p){
for(p=0,i=n-j;i<n;i++) y[p++]=i;
for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for(i=0;i<n;i++) wv[i]=x[y[i]];
for(i=0;i<m;i++) wt[i]=0;
for(i=0;i<n;i++) wt[wv[i]]++;
for(i=1;i<m;i++) wt[i]+=wt[i-1];
for(i=n-1;i>=0;i--) sa[--wt[wv[i]]]=y[i];
for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
}
int Rank[maxn],height[maxn];
void calheight(int *r,int *sa,int n){
int i , j , k = 0 ;
for( i=1 ; i<=n ; i++ ) Rank[sa[i]]=i;
for(i=0;i<n;i++) {
if(k)k--;
int j = sa[Rank[i]-1];
while(r[i+k]==r[j+k]) k++ ;
height[Rank[i]] = k ;
}
return;
}
int r[maxn] , sa[maxn] ;
char str[maxn] ;
struct Node{
int cnt ;
int height ;
LL sum ;
Node(){}
Node( int _cnt , int _height , LL _sum ){
cnt = _cnt ;
height = _height ;
sum = _sum ;
}
};
int top ;
Node s[maxn] ;
int main(){
int k ;
while( scanf( "%d" , &k ) != EOF ) {
LL ans = 0 ;
if( k == 0 ) break;
scanf( "%s" , str ) ;
int len1 = strlen( str ) ;
str[len1] = '$' ; str[len1+1] = 0 ;
scanf( "%s" , str + len1 + 1 ) ;
int len = strlen( str ) ;
for( int i = 0 ; i < len ; i ++ ) r[i] = str[i] ; r[len] = 0 ;
da( r , sa , len + 1 , 200 ) ;
calheight( r , sa , len ) ;
top = 0 ;
int cnt = 0 ;
if( sa[1] < len1 ) cnt ++ ;
if( sa[1] != len1 ) {
s[top++] = Node( cnt , height[2] - k + 1 , height[2] - k + 1 ) ;
}
for( int i = 2 ; i <= len ; i ++ ) {
if( sa[i] > len1 && top > 0 ) ans += max( (LL)0 , s[top-1].sum ) ;
if( height[i] < k ) {
top = 0 ;
}
if( sa[i] < len1 ) cnt = 1 ;
else if( sa[i] > len1 )cnt = 0 ;
else continue ;
while( top > 0 && s[top-1].height >= height[i+1] - k + 1 ) {
top -- ;
cnt += s[top].cnt ;
}
LL tmp = (LL)cnt * ( height[i+1] - k + 1 ) ;
if( top > 0 ) tmp += s[top-1].sum ;
s[top++] = Node( cnt , height[i+1] - k + 1 , tmp ) ;
}
top = 0 ;
cnt = 0 ;
if( sa[1] > len1 ) cnt ++ ;
if( sa[1] != len1 ) {
s[top++] = Node( cnt , height[2] - k + 1 , height[2] - k + 1 ) ;
}
for( int i = 2 ; i <= len ; i ++ ) {
if( sa[i] < len1 && top > 0 ) ans += max( (LL)0 , s[top-1].sum ) ;
if( height[i] < k ) {
top = 0 ;
}
if( sa[i] > len1 ) cnt = 1 ;
else if( sa[i] < len1 )cnt = 0 ;
else continue ;
while( top > 0 && s[top-1].height >= height[i+1] - k + 1 ) {
top -- ;
cnt += s[top].cnt ;
}
LL tmp = (LL)cnt * ( height[i+1] - k + 1 ) ;
if( top > 0 ) tmp += s[top-1].sum ;
s[top++] = Node( cnt , height[i+1] - k + 1 , tmp ) ;
}
printf( "%lld\n" , ans ) ;
}
return 0 ;
}