题意
传送门 POJ 3415
题解
子串是原串中连续的一段,也可以定义为前缀的后缀或后缀的前缀。
统计分别属于 $A, B$、长度不小于 $K$ 的相同子串对的个数,那么将 $A, B$ 用一个不属于这两个串的字符拼接起来(避免拼接位置对结果产生影响),构造后缀数组以及高度数组($lcp[i]$ 为 $sa[i], sa[i+1]$ 的最长公共前缀)。设任意两个后缀在后缀数组中的索引分别为
$i, j\ (i < j)$,此时可以快速地求解它们的最长公共前缀,即 $\min\{lcp[i], lcp[i+1], \dots, lcp[j-1]\}$,那么答案为
$$\sum\limits_{i=0}^{n}\sum\limits_{j=i+1}^{n}\max\big(0,\ \min\{lcp[i], lcp[i+1], \dots, lcp[j-1]\} - K + 1\big),\quad sa[i], sa[j] \text{ 分属 } A, B$$
(最长公共前缀为 $l \ge K$ 的一对后缀,其长度为 $K, K+1, \dots, l$ 的公共前缀各贡献一个三元组,共 $l - K + 1$ 个;两个后缀哪一个属于 $A$ 不限。)
后缀数组 + 单调栈
$\min$ 函数有随着区间的扩大值单调不增的性质,那么从左向右扫描高度数组,维护一个单调递增栈,统计分别属于 $A, B$ 的后缀与当前右界代表的后缀在单调递增的最长公共前缀中的数量,同时统计当前右界可以与左边的后缀构成的三元组数量。
#include <algorithm>
#include <iostream>
#include <string>
using namespace std;

// Capacity constant: every working array below holds (maxn << 1) ints.
#define maxn 100005

typedef long long ll;

// The two input strings; main() appends a separator and B onto A.
string A;
string B;

int n;     // length of the concatenated text (set in main)
int k;     // current doubling step, written by construct_sa and read by cmp_sa
int na;    // length of the original A (set in main)
int limit; // minimum common-substring length read from input

int rnk[maxn << 1]; // rank of each suffix start position
int tmp[maxn << 1]; // scratch ranks for the next doubling round
int sa[maxn << 1];  // suffix array
int lcp[maxn << 1]; // height array: lcp[i] relates sa[i] and sa[i + 1]
// Strict-weak ordering used while building the suffix array: compares the
// suffixes starting at i and j by the pair (rnk[x], rnk[x + k]), where a
// second component that would fall past position n counts as -1.
bool cmp_sa(int i, int j)
{
    if (rnk[i] == rnk[j])
    {
        const int second_i = (i + k > n) ? -1 : rnk[i + k];
        const int second_j = (j + k > n) ? -1 : rnk[j + k];
        return second_i < second_j;
    }
    return rnk[i] < rnk[j];
}
// Builds the suffix array of s into sa[0..n] by prefix doubling.
//
// s  - the text; only s[0..n-1] is read, n being the global text length.
// sa - output array with room for n + 1 entries; on return sa[i] is the start
//      position of the i-th smallest suffix, the empty suffix (position n)
//      included.
//
// Side effects: overwrites the globals k, rnk and tmp. On return rnk holds the
// ranks computed by the last doubling round.
void construct_sa(string &s, int *sa)
{
    // Round zero: rank every suffix by its first character; the empty suffix
    // at position n gets -1 so that it sorts before everything else.
    for (int i = 0; i <= n; ++i)
    {
        sa[i] = i;
        rnk[i] = i < n ? s[i] : -1;
    }
    // Each round sorts by the first 2k characters using the ranks of the first
    // k. The bound is k <= n (not k < n) so that a text of length 1 still gets
    // one sorting round; for longer texts any extra round leaves the already
    // final order untouched.
    for (k = 1; k <= n; k <<= 1)
    {
        sort(sa, sa + n + 1, cmp_sa);
        // Re-rank: neighbours that compare equal under cmp_sa share a rank.
        tmp[sa[0]] = 0;
        for (int i = 1; i <= n; ++i)
            tmp[sa[i]] = tmp[sa[i - 1]] + (cmp_sa(sa[i - 1], sa[i]) ? 1 : 0);
        // std::copy comes from <algorithm>, which this file already includes.
        copy(tmp, tmp + n + 1, rnk);
    }
}
// Fills the height array from a finished suffix array.
//
// s   - the text; only s[0..n-1] is read.
// sa  - suffix array of s with n + 1 entries.
// lcp - output; lcp[r] receives the length of the longest common prefix of
//       the suffixes starting at sa[r] and sa[r + 1], for r in [0, n).
//
// Side effect: rnk is overwritten with the inverse permutation of sa.
void construct_lcp(string &s, int *sa, int *lcp)
{
    for (int pos = 0; pos <= n; ++pos)
        rnk[sa[pos]] = pos;

    int matched = 0;
    for (int start = 0; start < n; ++start)
    {
        // Suffix that immediately precedes `start` in sorted order.
        const int prev = sa[rnk[start] - 1];
        // Carry over the previous match length minus one, then extend it.
        if (matched > 0)
            --matched;
        while (start + matched < n && prev + matched < n &&
               s[start + matched] == s[prev + matched])
            ++matched;
        lcp[rnk[start] - 1] = matched;
    }
}
// One entry of the stack used by solve(): sz[0] and sz[1] count the suffixes
// merged into the entry (index 1 is used for start positions below na, index 0
// for the rest) and h is the height they share.
struct node
{
    int sz[2];
    int h;
};

// Backing storage for the stack.
node st[maxn << 1];
// Scans the height array left to right with a stack of non-decreasing heights
// and returns the accumulated count.
//
// For every rank i with lcp[i] >= limit, the suffix sa[i] is pushed with
// weight h = lcp[i] - limit + 1; stack entries whose h is not smaller are
// merged into it. cnt[c] always equals the sum of sz[c] * h over the stack,
// and the suffix sa[i + 1] then adds cnt of the opposite class to the result.
// A position with lcp[i] < limit empties the stack.
//
// Reads the globals n, na, limit, sa and lcp; uses st as scratch storage.
ll solve()
{
    ll res = 0;
    // Start from an explicitly empty stack: these values used to be left
    // indeterminate until the first lcp[i] < limit was seen.
    ll cnt[2] = {0, 0};
    int top = 0;
    for (int i = 0; i < n; ++i)
    {
        if (lcp[i] < limit)
        {
            top = 0;
            cnt[0] = cnt[1] = 0;
            continue;
        }
        // Class 1: suffix starts inside the original A; class 0: otherwise.
        const int id = sa[i] < na;
        const int h = lcp[i] - limit + 1;
        int sz[2] = {0, 0};
        ++sz[id];
        // Fold every entry at least as tall as h into the new entry.
        while (top > 0 && h <= st[top - 1].h)
        {
            --top;
            for (int j = 0; j < 2; ++j)
            {
                cnt[j] -= static_cast<ll>(st[top].sz[j]) * st[top].h;
                sz[j] += st[top].sz[j];
            }
        }
        st[top].sz[0] = sz[0];
        st[top].sz[1] = sz[1];
        st[top].h = h;
        ++top;
        for (int j = 0; j < 2; ++j)
            cnt[j] += static_cast<ll>(sz[j]) * h;
        // The right-hand suffix pairs with every stacked suffix of the other
        // class.
        const int id2 = sa[i + 1] < na;
        res += cnt[id2 ^ 1];
    }
    return res;
}
int main()
{
while (cin >> limit && limit)
{
cin >> A >> B;
na = A.length();
A += '$' + B;
n = A.length();
construct_sa(A, sa);
construct_lcp(A, sa, lcp);
cout << solve() << endl;
}
return 0;
}
后缀数组 + 并查集
由于后缀数组的有序性,最长公共前缀不小于 $K$ 的后缀在后缀数组中是处于连续位置的,那么可以依次扫描求出这样的连续位置。但对于任意一对后缀,需要 $RMQ$ 求解最长公共前缀,直接枚举复杂度 $O(N^2)$ 难以胜任,考虑从有最大值的最长公共前缀的后缀对 $(i,j)$ 开始统计,那么已统计过的后缀对也必然满足小于这个值的公共前缀。
运用高度数组,以并查集维护属于同一段连续位置的后缀,此时就可以只进行一次枚举高度的处理,从大到小依次统计有这样长度的公共前缀的后缀对,在合并操作时统计新增的后缀对数。
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
using namespace std;

// Capacity constant used to size the working arrays below.
#define maxn 100005

typedef long long ll;

// The two input strings; main() appends a separator and B onto A.
string A;
string B;

int n;     // length of the concatenated text (set in main)
int k;     // current doubling step, written by construct_sa and read by cmp_sa
int na;    // length of the original A (set in main)
int limit; // minimum common-substring length read from input

int rnk[maxn << 1]; // rank of each suffix start position
int tmp[maxn << 1]; // scratch ranks for the next doubling round
int sa[maxn << 1];  // suffix array
int lcp[maxn << 1]; // height array: lcp[i] relates sa[i] and sa[i + 1]
// Comparator for the doubling sort: orders start positions i and j by their
// current rank, breaking ties with the rank k positions further on (-1 when
// that position lies beyond n).
bool cmp_sa(int i, int j)
{
    const int first_i = rnk[i];
    const int first_j = rnk[j];
    if (first_i != first_j)
        return first_i < first_j;

    int tail_i = -1;
    if (i + k <= n)
        tail_i = rnk[i + k];
    int tail_j = -1;
    if (j + k <= n)
        tail_j = rnk[j + k];
    return tail_i < tail_j;
}
// Builds the suffix array of s into sa[0..n] by prefix doubling.
//
// s  - the text; only s[0..n-1] is read, n being the global text length.
// sa - output array with room for n + 1 entries; on return sa[i] is the start
//      position of the i-th smallest suffix, the empty suffix (position n)
//      included.
//
// Side effects: overwrites the globals k, rnk and tmp. On return rnk holds the
// ranks computed by the last doubling round.
void construct_sa(string &s, int *sa)
{
    // Round zero: rank every suffix by its first character; the empty suffix
    // at position n gets -1 so that it sorts before everything else.
    for (int i = 0; i <= n; ++i)
    {
        sa[i] = i;
        rnk[i] = i < n ? s[i] : -1;
    }
    // Each round sorts by the first 2k characters using the ranks of the first
    // k. The bound is k <= n (not k < n) so that a text of length 1 still gets
    // one sorting round; for longer texts any extra round leaves the already
    // final order untouched.
    for (k = 1; k <= n; k <<= 1)
    {
        sort(sa, sa + n + 1, cmp_sa);
        // Re-rank: neighbours that compare equal under cmp_sa share a rank.
        tmp[sa[0]] = 0;
        for (int i = 1; i <= n; ++i)
            tmp[sa[i]] = tmp[sa[i - 1]] + (cmp_sa(sa[i - 1], sa[i]) ? 1 : 0);
        // std::copy comes from <algorithm>, which this file already includes.
        copy(tmp, tmp + n + 1, rnk);
    }
}
// Fills the height array from a finished suffix array.
//
// s   - the text; only s[0..n-1] is read.
// sa  - suffix array of s with n + 1 entries.
// lcp - output; lcp[r] receives the length of the longest common prefix of
//       the suffixes starting at sa[r] and sa[r + 1], for r in [0, n).
//
// Side effect: rnk is overwritten with the inverse permutation of sa.
void construct_lcp(string &s, int *sa, int *lcp)
{
    for (int pos = 0; pos <= n; ++pos)
        rnk[sa[pos]] = pos;

    int matched = 0;
    for (int start = 0; start < n; ++start)
    {
        const int slot = rnk[start] - 1;
        // Suffix that immediately precedes `start` in sorted order.
        const int prev = sa[slot];
        // Carry over the previous match length minus one, then extend it.
        if (matched > 0)
            --matched;
        while (start + matched < n && prev + matched < n)
        {
            if (s[start + matched] != s[prev + matched])
                break;
            ++matched;
        }
        lcp[slot] = matched;
    }
}
// A pair of suffix start positions that are adjacent in the suffix array.
typedef pair<int, int> P;

// Union-find state, indexed by suffix start position.
int par[maxn << 1];    // parent pointer; a root points to itself
int rk[maxn << 1];     // rank used to decide which root absorbs the other
int cnt[maxn << 1][2]; // per-root member counts, split into two classes

// hs[h] collects the adjacent pairs whose height equals h.
vector<P> hs[maxn];
// Returns the representative of the set containing x and re-points every node
// visited on the way directly at that representative.
int find(int x)
{
    // First pass: walk up to the root.
    int root = x;
    while (par[root] != root)
        root = par[root];

    // Second pass: compress the path just walked.
    while (par[x] != root)
    {
        const int up = par[x];
        par[x] = root;
        x = up;
    }
    return root;
}
// Merges the sets containing x and y (union by rank) and returns the number of
// new cross-class pairs the merge creates: members counted in cnt[.][0] of one
// set times members counted in cnt[.][1] of the other, in both directions.
// Returns 0 and changes nothing when x and y already share a set.
ll unite(int x, int y)
{
    x = find(x);
    y = find(y);
    // Without this guard a self-merge would bump the rank, double the counts
    // and report pairs that do not exist.
    if (x == y)
        return 0;
    // Keep y as the root of higher (or equal) rank; x is attached below it.
    if (rk[x] > rk[y])
        swap(x, y);
    if (rk[y] == rk[x])
        ++rk[y];
    const ll res = static_cast<ll>(cnt[x][0]) * cnt[y][1] +
                   static_cast<ll>(cnt[x][1]) * cnt[y][0];
    par[x] = y;
    cnt[y][0] += cnt[x][0];
    cnt[y][1] += cnt[x][1];
    return res;
}
// Processes heights from the largest down to limit, merging the adjacent
// suffix pairs of each height with unite() and summing the running number of
// cross-class pairs after every height level.
//
// Reads the globals n, na, limit, sa and lcp; resets par, rk, cnt and the
// buckets hs[limit..maxh].
ll solve()
{
    // Every position starts as its own set: positions below na count towards
    // class 0, all others towards class 1.
    for (int pos = 0; pos <= n; ++pos)
    {
        par[pos] = pos;
        rk[pos] = 0;
        if (pos < na)
        {
            cnt[pos][0] = 1;
            cnt[pos][1] = 0;
        }
        else
        {
            cnt[pos][1] = 1;
            cnt[pos][0] = 0;
        }
    }

    const int maxh = max(na, n - 1 - na);
    for (int h = limit; h <= maxh; ++h)
        hs[h].clear();

    // Bucket each adjacent pair by its height; pairs below limit are ignored.
    for (int r = 0; r < n; ++r)
    {
        if (lcp[r] < limit)
            continue;
        hs[lcp[r]].push_back(P(sa[r], sa[r + 1]));
    }

    ll total = 0;
    ll pairs = 0; // cross-class pairs connected so far
    for (int h = maxh; h >= limit; --h)
    {
        const vector<P> &bucket = hs[h];
        for (vector<P>::const_iterator it = bucket.begin(); it != bucket.end(); ++it)
            pairs += unite(it->first, it->second);
        total += pairs;
    }
    return total;
}
int main()
{
while (cin >> limit && limit)
{
cin >> A >> B;
na = A.length();
A += '$' + B;
n = A.length();
construct_sa(A, sa);
construct_lcp(A, sa, lcp);
cout << solve() << endl;
}
return 0;
}