题目
分析
统计以
S
i
S_i
Si开头的形如
AA
\text{AA}
AA的子串的数量,存入
L
[
i
]
L[i]
L[i];统计以
S
i
S_i
Si结尾的形如
AA
\text{AA}
AA的子串,存入
R
[
i
]
R[i]
R[i]。于是把可以把它们拼起来,答案就是
∑
i
=
2
n
(
L
[
i
]
×
R
[
i
−
1
]
)
\sum \limits_{i = 2}^{n} (L[i] \times R[i - 1])
i=2∑n(L[i]×R[i−1])。
L
L
L与
R
R
R数组的处理,暴力枚举+哈希判断相等是
O
(
n
2
)
O(n^2)
O(n2)的,考虑优化这个东西。
我们枚举
A
\text{A}
A的长度
l
l
l,那么一个
A
\text{A}
A在
S
S
S中会经过且仅经过一个
S
k
⋅
l
S_{k \cdot l}
Sk⋅l,如图所示,蓝点是
S
k
⋅
l
S_{k \cdot l}
Sk⋅l,可见任何一个长度为
l
l
l的子串必然经过一个蓝点。
那么我们把这个串看成左右两端,即以
S
k
⋅
l
S_{k \cdot l}
Sk⋅l开始的后缀(下图中橙色示意的范围)和以
S
k
⋅
l
S_{k \cdot l}
Sk⋅l开始的前缀(下图中绿色示意的范围),这两个前后缀在
S
k
⋅
l
S_{k \cdot l}
Sk⋅l处重合。
不妨假设这个串是某个
AA
\text{AA}
AA的子串的前一个
A
\text{A}
A,那么它后面紧接着一个跟它一模一样的:
即橙色(两个后缀)和绿色(两个前缀)分别相等。
发现了,我们只需要找到以
S
k
⋅
l
S_{k \cdot l}
Sk⋅l与
S
(
k
+
1
)
⋅
l
S_{(k + 1) \cdot l}
S(k+1)⋅l结尾的最长公共后缀(LCS),和以
S
k
⋅
l
S_{k \cdot l}
Sk⋅l与
S
(
k
+
1
)
⋅
l
S_{(k + 1) \cdot l}
S(k+1)⋅l开头的最长公共前缀(LCP),这两个二分+哈希
O
(
n
log
n
)
O(n \log n)
O(nlogn)即可找到。
找到了过后,看下图(下图的
l
=
7
l = 7
l=7,且只是截取了
S
S
S中的一部分),假设橙色标记的是LCS,绿色标记的是LCP,那么红色标记三对子串都是形如
AA
\text{AA}
AA的:
这个时候我们就左边的三个端点(灰色)的
R
[
i
]
R[i]
R[i]全部加一,右边的三个端点(灰色)的
L
[
i
]
L[i]
L[i]全部加一即可,只有区间加法,差分一下即可 (当然线段树也可以) 。
给不明白差分的小伙伴;
令 L ′ [ i ] = L [ i ] − L [ i − 1 ] L'[i] = L[i] - L[i - 1] L′[i]=L[i]−L[i−1],那么我们对 L ′ [ i ] L'[i] L′[i]进行操作,最后可以通过 L [ i ] = ∑ j = 1 i L ′ [ j ] L[i] = \sum \limits_{j = 1}^{i} L'[j] L[i]=j=1∑iL′[j],来还原 L [ i ] L[i] L[i]。
发现 L [ i ] L[i] L[i]其实是 L ′ [ i ] L'[i] L′[i]的前缀和数组,那么 L [ i ] L[i] L[i]的区间加法( [ l , r ] [l, r] [l,r]上加 d d d),在 L ′ [ i ] L'[i] L′[i]上只用改两个点: L ′ [ l ] + = d L'[l] += d L′[l]+=d, L ′ [ r + 1 ] − = d L'[r + 1] -= d L′[r+1]−=d,这样一来,想想算前缀和的过程, [ l , r ] [l, r] [l,r]这一段全部都多了 d d d。
总时间复杂度
O
(
(
n
1
+
n
2
+
⋯
+
n
n
)
log
n
)
=
O
(
n
log
2
n
)
O\left(\left(\dfrac{n}{1}+\dfrac{n}{2}+\cdots+\dfrac{n}{n}\right)\log n\right)=O(n\log^2 n)
O((1n+2n+⋯+nn)logn)=O(nlog2n)。(用SAM/SA可以少个log?)
代码
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
int Read() {
int x = 0; bool f = false; char c = getchar();
while (c < '0' || c > '9')
f |= c =='-', c = getchar();
while (c >= '0' && c <= '9')
x = x * 10 + (c ^ 48), c = getchar();
return f ? -x : x;
}
typedef long long LL;
const int MAXN = 30000;
const int PRIME = 233;
const int MOD = 1000000009;
int N;
char S[MAXN + 5];
int L[MAXN + 5], R[MAXN + 5];
int Hash[MAXN + 5], Pow[MAXN + 5];
int Key(int lft, int rgt) {
return (Hash[rgt] - (LL)Hash[lft - 1] * Pow[rgt - lft + 1] % MOD + MOD) % MOD;
}
int GetLCS(int i, int j) {
int lft = 0, rgt = std::min(j - i, i) + 1; // 注意上界不要超了,否则会访问到不该访问的地方
while (lft + 1 < rgt) {
int mid = (lft + rgt) >> 1;
if (Key(i - mid + 1, i) == Key(j - mid + 1, j))
lft = mid;
else
rgt = mid;
}
return lft;
}
int GetLCP(int i, int j) {
int lft = 0, rgt = std::min(j - i, N - j + 1) + 1; // 这里也是
while (lft + 1 < rgt) {
int mid = (lft + rgt) >> 1;
if (Key(i, i + mid - 1) == Key(j, j + mid - 1))
lft = mid;
else
rgt = mid;
}
return lft;
}
int main() {
Pow[0] = 1;
int T = Read();
while (T--) {
scanf("%s", S + 1);
N = strlen(S + 1);
for (int i = 1; i <= N; i++) {
L[i] = R[i] = 0;
Pow[i] = (LL)Pow[i - 1] * PRIME % MOD;
Hash[i] = ((LL)Hash[i - 1] * PRIME + (S[i] - 'a')) % MOD;
}
for (int len = 1; 2 * len <= N; len++) {
for (int i = 1; i + len <= N; i += len) {
int lcs = GetLCS(i, i + len), lcp = GetLCP(i, i + len);
if (lcs + lcp - 1 >= len) {
L[i - lcs + 1]++, L[i + lcp - len + 1]--;
R[i - lcs + 2 * len]++, R[i + lcp + len]--; // 这四个点自己参照图算一下就能找到
}
}
}
for (int i = 1; i <= N; i++)
L[i] += L[i - 1], R[i] += R[i - 1]; // 由差分数组还原
long long Ans = 0;
for (int i = 2; i <= N; i++)
Ans += (long long)L[i] * R[i - 1];
printf("%lld\n", Ans);
}
}