链接
题意
给出两个字符串s1、s2和一个整数K,求有多少个长度大于K的公共子串。
思路
后缀数组的一个经典问题,对height数组分组后使用单调栈将复杂度优化至O(n),单调栈还真是强大啊。。。
整个过程是先将s1、s2串起来,用比较小的数分隔,求出后缀数组,按照K的公共前缀长度去分组,对组内每个s2后缀,前方出现的每个同组s1后缀都与其存在长度不小于K的公共前缀。之后再计算一遍组内s1后缀前方出现的各个同组s2后缀与其公共前缀,累加即可得到结果。
对height数组建st表,可以查询任意两个后缀的最长公共前缀,然而即使这么做复杂度也是O(n^2),不能满足时间要求。
在我们扫描分组的时候,可以对height维护一个“单调栈”,由于后缀i、j(rank[i] < rank[j])之间的最长公共前缀是min(height[rank[i] + 1], …, height[rank[i] + k], …, height[rank[j]],因此LCP(sa[rank[i]], sa[rank[j]])一定小于等于LCP(sa[rank[i] + 1], sa[rank[j]]),所以后入栈的后缀,需要对在它前方入栈且LCP大于它的后缀的贡献做消减,使整个栈满足“顺序性”和“单调性”。
PS:其实我自己写着写着也写懵逼了,看代码吧。。。这道题我几乎是看别人博客看到把代码背下来了才懂这个单调栈是怎么优化的。
代码
#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
typedef long long lint;
#define maxn 200200
bool cmp(int *r, int a, int b, int l)
{ return r[a] == r[b] && r[a + l] == r[b + l]; }
int ta[maxn], tb[maxn], bk[maxn];
void da(int *r, int *sa, int n, int m)
{
int i, j, p, *x = ta, *y = tb, *t;
for(i = 0; i < m; i++) bk[i] = 0;
for(i = 0; i < n; i++) bk[x[i] = r[i]]++;
for(i = 1; i < m; i++) bk[i] += bk[i-1];
for(i = 0; i < n; i++) sa[--bk[x[i]]] = i;
for(j = 1, p = 1; p < n; j *= 2, m = p)
{
for(p = 0, i = n - j; i < n; i++) y[p++] = i;
for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
for(i = 0; i < m; i++) bk[i] = 0;
for(i = 0; i < n; i++) bk[x[i]]++;
for(i = 1; i < m; i++) bk[i] += bk[i-1];
for(i = n-1; i >= 0; i--) sa[--bk[x[y[i]]]] = y[i];
for(t = x, x = y, y = t, x[sa[0]] = 0, p = 1, i = 1; i < n; i++)
x[sa[i]] = cmp(y, sa[i-1], sa[i], j) ? p - 1 : p++;
}
}
int Rank[maxn], SA[maxn], Height[maxn];
void calheight(int *r, int n)
{
for(int i = 0; i < n; i++) Rank[SA[i]] = i;
for(int i = 0, k = 0; i < n; i++)
{
k ? k-- : 0;
if(Rank[i])
while(r[i + k] == r[SA[Rank[i] - 1] + k])
k++;
Height[Rank[i]] = k;
}
}
struct _node
{
int h, t;
} _stack[maxn];
char s1[maxn >> 1], s2[maxn >> 1];
int r[maxn], n, m;
int main()
{
int K;
while((cin >> K) && K)
{
scanf("%s%s", s1, s2);
n = 0;
for(int i = 0; s1[i]; i++)
r[n++] = s1[i];
r[m = n++] = 1;
for(int i = 0; s2[i]; i++)
r[n++] = s2[i];
r[n++] = 0;
da(r, SA, n, 1<<8);
calheight(r, n);
lint o = 0, tot = 0;
for(int i = 3, top = 0, cnt; i < n; i++)
{
if(Height[i] < K) { tot = top = 0; continue; }
cnt = 0;
if(SA[i - 1] < m) { tot += Height[i] - K + 1; cnt++; }
while(top > 0 && Height[i] < _stack[top - 1].h)
{
tot -= _stack[top - 1].t * (_stack[top - 1].h - Height[i]);
cnt += _stack[top - 1].t;
top--;
}
_stack[top].h = Height[i];
_stack[top++].t = cnt;
if(SA[i] > m) o += tot;
}
tot = 0;
for(int i = 3, top = 0, cnt; i < n; i++)
{
if(Height[i] < K) { tot = top = 0; continue; }
cnt = 0;
if(SA[i - 1] > m) { tot += Height[i] - K + 1; cnt++; }
while(top > 0 && Height[i] < _stack[top - 1].h)
{
tot -= _stack[top - 1].t * (_stack[top - 1].h - Height[i]);
cnt += _stack[top - 1].t;
top--;
}
_stack[top].h = Height[i];
_stack[top++].t = cnt;
if(SA[i] < m) o += tot;
}
cout << o << endl;
}
return 0;
}