链接
POJ - 3415
题意:给出两个字符串s1、s2和一个整数K,求有多少个长度大于K的公共子串。
思路:
参考题解:
现将s = s1 + ‘$’ + s2;(做出后缀数组)
有后缀数组height的特性可知可将排名1~n的后缀串通过是否公共子串大于k来分组
我们只需求出每一组中的长度大于K的公共子串, 在相加即可得到答案;
对于每一组
朴素方法:时间复杂度
O
(
n
2
)
O(n^2)
O(n2)
枚举组内的每一个s1中的后缀再枚举在该组中的每一个s2中的后缀, 答案为(height[i] - k + 1) * (height[j] - k + 1)
单调栈优化:时间复杂度
O
(
n
)
O(n)
O(n)
对于一个组, 我们先考虑每一个s2的后缀和排名在它前面的s1的后缀组成的公共子串个数, 再考虑每个s1后缀和排名在它前面的s2的后缀组成的公共子串个数, 这样就考虑了所有情况,
考虑每一个s2的后缀和排名在它前面的s1的后缀组成的公共子串个数:对于找到的一个s2的一个后缀
i
i
i, 和排在它前面的s1的后缀
j
j
j, 公共子串数量 == (height[i] - k + 1) * min(height[j] - k + 1, height[j + 1] - k + 1, , height[i] - k + 1]) ,
单调栈解释:如果当前的height[i]大于栈顶就加入栈,反之就弹出栈, 但弹出栈还要记录这个弹出的元素可能还会产生的贡献, 我们将这个贡献记录到height[i]头上, 这个贡献和height[i]]大小相同, 因为当前的height[i]小于它, 根据height数组的含义, 它和后来的串的1公共子串长度必然小于等于height[i]
细节见于代码注释
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
using namespace std;
#define ll long long
const int N = 300010;
int n, m;
char s[N], a[N];
int k;
int sa[N], x[N], y[N], c[N], rk[N], height[N], base[N], f[N][30];
int id[N];
void get_sa()
{
for (int i = 1; i <= n; i ++ ) c[x[i] = s[i]] ++ ;
for (int i = 2; i <= m; i ++ ) c[i] += c[i - 1];
for (int i = n; i; i -- ) sa[c[x[i]] -- ] = i;
for (int k = 1; k <= n; k <<= 1)
{
int num = 0;
for (int i = n - k + 1; i <= n; i ++ ) y[ ++ num] = i;
for (int i = 1; i <= n; i ++ )
if (sa[i] > k)
y[ ++ num] = sa[i] - k;
for (int i = 1; i <= m; i ++ ) c[i] = 0;
for (int i = 1; i <= n; i ++ ) c[x[i]] ++ ;
for (int i = 2; i <= m; i ++ ) c[i] += c[i - 1];
for (int i = n; i; i -- ) sa[c[x[y[i]]] -- ] = y[i], y[i] = 0;
swap(x, y);
x[sa[1]] = 1, num = 1;
for (int i = 2; i <= n; i ++ )
x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++ num;
if (num == n) break;
m = num;
}
}
void get_height()
{
for (int i = 1; i <= n; i ++ ) rk[sa[i]] = i;
for (int i = 1, k = 0; i <= n; i ++ )
{
if (rk[i] == 1) continue;
if (k) k -- ;
int j = sa[rk[i] - 1];
while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) k ++ ;
height[rk[i]] = k;
}
}
void init_rmq()
{
base[0] = -1;
for(int i = 1; i <= n; i ++)
{
f[i][0] = height[i];
base[i] = base[i>>1] + 1;
}
for(int j = 1; j <= 18; j ++)
{
for(int i = 1; i + (1 << (j - 1)) <= n; i++)
{
f[i][j] = min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
}
}
int lcp(int x, int y) //第x和第y个后缀(不是排名)的最长公共前缀
{
if(x == y) return n - x + 1;
x = rk[x], y = rk[y];
if(x > y) swap(x, y);
x ++;
int t = base[y - x + 1];
return min(f[x][t], f[y - (1 << t) + 1][t]);
}
void init()
{
memset(c, 0, sizeof c);
memset(x, 0, sizeof x);
}
struct node
{
int val, cnt;
//val 是与排名前一位的最长公共前缀
//cnt是排名在他后面, 且连续且最长公共前缀大于它的数。
};
node b[N];
int ans[N];
int main()
{
while(scanf("%d", &k) && k)
{
init();
scanf("%s", s + 1);
scanf("%s", a + 1);
int len1 = strlen(s + 1);
int len2 = strlen(a + 1);
s[len1 + 1] = 1;
s[len1 + len2 + 2] = '\0';
for(int i = len1 + 2, j = 1; j <= len2; j ++, i++)
{
s[i] = a[j];
}
n = len1 + len2 + 1;
m = 200;
get_sa();
get_height();
init_rmq();
// for(int i = 1; i <= n; i ++)
// {
// //cout << height[i] << endl;
// for(int j = sa[i]; j <= n; j ++) cout << s[j];
// cout << endl;
// }
long long sum = 0, ans = 0, top = 0, tot = 0, cnt = 0;//tot栈中元素的贡献
for(int i = 2; i <= n; i++)
{
cnt = 0;
if(height[i] < k) top = 0, tot = 0;
else
{
if(sa[i - 1] <= len1) // 如果排名第i - 1是A的前缀
{
cnt ++;
tot += height[i] - k + 1;//记录第i - 1个串的贡献
}
//如果是B的后缀则cnt为零
while(top && b[top - 1].val >= height[i])
{
top --;
tot -= b[top].cnt * (b[top].val - height[i]);//b[top].cnt 在他之后已经出栈的元素
cnt += b[top].cnt;
}
b[top].cnt = cnt;
b[top ++].val = height[i];
if(sa[i] > len1 + 1) ans += tot;
}
}
tot = 0, top = 0;
for(int i = 2; i <= n; i++)
{
cnt = 0;
if(height[i] < k) top = tot = 0;
else
{
if(sa[i - 1] > len1 + 1)
{
cnt ++;
tot += height[i] - k + 1;
}
while(top != 0 && b[top - 1].val >= height[i])
{
top --;
tot -= b[top].cnt * (b[top].val - height[i]);
cnt += b[top].cnt;
}
b[top].val = height[i];
b[top ++].cnt = cnt;
if(sa[i] <= len1) ans += tot;
}
}
cout << ans << endl;
}
return 0;
}