题意:给出两个字符串,统计有多少长度大于等于K的相同子串。
思路:用后缀数组计算出height后,很容易想到,如果相邻的两个后缀是位于不同的字符串的,会产生height - K + 1个不相同的满足题意的子串。但是,根据题意可知,不同位置的子串也要统计,这样就会导致我们少计算很多。但是,如果是暴力计算的话,复杂度又会超时。我们就需要一个高效的统计的方法。
在这里,我们可以利用单调栈进行优化。首先,因为任意两个位置的lcp是对应的区间最小值,所以,如果当前位置的height小于栈顶的height,对于后面的后缀,能产生的最长lcp是不大于较小的height的,所以,即使较大的height的值很大,但是它能贡献的将变成小的height的大小。
所以,整体思路是这个样子的:我们对于A串,扫描B串中能产生多少满足题意的子串。再反过来扫描。在扫描的过程中,用单调栈进行计数。
注意:需要十分注意的是height和sa的关系,当前位置i的后缀产生的lcp是记录在height[i+1]中,所以,height[i]对应的后缀是i。
代码如下:
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
static const int maxn =1001000;//three times of length
int rk[maxn];//0 - n-1
int sa[maxn];//1 - n
int height[maxn];//1 - n
int wa[maxn],wb[maxn],wv[maxn],ws[maxn];
#define F(x) ((x)/3 + ((x)%3 == 1 ? 0:tb))
#define G(x) ((x) < tb ? (x)*3+1 : ((x)-tb)*3 + 2)
int c0(int *r, int a, int b){
return r[a] == r[b] && r[a+1] == r[b+1] && r[a+2] == r[b+2];
}
int c12(int k, int *r, int a,int b){
if (k == 2)
return r[a] < r[b] || r[a] == r[b] && c12(1,r,a+1,b+1);
else
return r[a] < r[b] || r[a] == r[b] && wv[a+1] < wv[b+1];
}
void radix_sort(int *r, int *a,int *b,int n,int m) {
int i;
for (i = 0; i < n; i++) wv[i] = r[a[i]];
for (i = 0; i < m; i++) ws[i] = 0;
for (i = 0; i < n; i++) ws[wv[i]]++;
for (i = 1; i < m; i++) ws[i] += ws[i-1];
for (i = n-1; i >= 0; i--) b[--ws[wv[i]]] = a[i];
return ;
}
void dc3(int *r,int *sa,int n, int m){
int i,j,*rn = r + n, *san = sa + n;
int ta = 0, tb = (n+1)/3,tbc = 0, p;
r[n] = r[n+1] = 0;
for(i = 0; i < n; i++)
if(i%3 != 0) wa[tbc++] = i;
radix_sort(r+2, wa, wb, tbc, m);
radix_sort(r+1, wb, wa, tbc, m);
radix_sort( r, wa, wb, tbc, m);
for (p = 1, rn[F(wb[0])] = 0, i = 1; i < tbc; i++)
rn[F(wb[i])] = c0(r,wb[i-1],wb[i]) ? p-1 : p++;
if(p < tbc) dc3(rn, san, tbc, p);
else
for (i = 0; i < tbc; i++) san[rn[i]] = i;
for(i = 0; i < tbc; i++)
if (san[i] < tb) wb[ta++] = san[i]*3;
if(n%3 == 1) wb[ta++] = n-1;
radix_sort(r, wb, wa, ta, m);
for(i = 0; i < tbc; i++)
wv[wb[i]=G(san[i])] = i;
for(i = 0,j = 0,p = 0; i < ta && j < tbc; p++)
sa[p] = c12(wb[j]%3,r,wa[i],wb[j]) ? wa[i++] : wb[j++];
for( ; i < ta; p++) sa[p] = wa[i++];
for ( ; j < tbc; p++) sa[p] = wb[j++];
return ;
}
void calc_sa(int *r, int n, int m){//attention: 1 <= r[i] <= m
r[n] = 0;//add zero, length : n + 1
dc3(r,sa,n+1,m);
}
void calc_height(int *r,int n){
int i,j,k = 0;
for (i = 0; i < n+1; i++)
rk[sa[i]] = i;
for (i = 0; i < n; height[rk[i++]] = k)//according to rank, only n times
for (k ? k-- : 0, j = sa[rk[i]-1]; r[i+k] == r[j+k]; k++)
;
return;
}
void print(int * r, int n){
for(int i = 1; i <= n; ++i){
for(int j = sa[i]; j < n; ++j)
putchar(r[j]);
putchar('\n');
}
}
// static const int MAX = 200100;
// int p[MAX];
// int d[MAX][20];
// void rmq_init(int n){
// p[0] = -1;
// for(int i = 1; i <= n; ++i)
// p[i] = i & (i-1)?p[i-1]:p[i-1]+1;
// for(int i = 1; i <= n; ++i) d[i][0] = height[i];
// for(int j = 1; j <= p[n]; ++j)
// for(int i = 1; i + (1 << j) - 1 <= n; ++i)
// d[i][j] = min(d[i][j-1],d[i+(1<<j-1)][j-1]);
// }
// int rmp_query(int l, int r){
// int k = p[r - l + 1];
// return min(d[l][k],d[r - (1<<k) + 1][k]);
// }
// int lcp(int l, int r){//l,r is the start postion of two suffix
// l = rank[l], r = rank[r];//we should turn them to the index in sa
// if(l > r) swap(l,r);l++;
// return rmp_query(l,r);
// }
int r[200010];
char str[100010];
int stack[200010][2];
int top;
int main(void)
{
//freopen("input.txt","r",stdin);
int K;
while(scanf("%d",&K),K){
scanf("%s",str);
int n1 = strlen(str);
copy(str,str+n1,r);
r[n1] = '$';
scanf("%s",str);
int n2 = strlen(str);
copy(str,str+n2,r+n1+1);
int n = n1 + n2 + 1;
calc_sa(r,n,256);
calc_height(r,n);
long long ans = 0LL;
long long tot = 0LL;
top = 0;
for(int i = 1; i <= n; ++i){
int cnt = 0;
if(height[i] < K){
top = tot = 0;
continue;
}
if(sa[i-1] < n1) tot += height[i] - K + 1,cnt = 1;
while(top > 0 && stack[top-1][0] >= height[i]){
top--;
tot -= (stack[top][0] - height[i]) * stack[top][1];
cnt += stack[top][1];
}
stack[top][0] = height[i],stack[top][1] = cnt;
top++;
if(sa[i] > n1)ans += tot;
}
top = tot = 0;
for(int i = 1; i <= n; ++i){
if(height[i] < K){
tot = top = 0;
continue;
}
int cnt = 0;
if(sa[i-1] > n1) tot += height[i] - K + 1, cnt = 1;
while(top > 0 && stack[top-1][0] >= height[i]){
top--;
tot -= (stack[top][0] - height[i]) * stack[top][1];
cnt += stack[top][1];
}
stack[top][0] = height[i],stack[top][1] = cnt;
top++;
if(sa[i] < n1)ans += tot;
}
printf("%lld\n",ans);
}
return 0;
}