题目描述:给定一个长度为 $n$ 的字符串,询问所有长度小于等于 $k$ 的子串中相等的对数。(答案对 $998244353$ 取模)
对数即:$s_i,s_j$ 是 $S$ 的子串,满足 $\operatorname{len}(s_i),\operatorname{len}(s_j)\le k$ 且 $s_i=s_j$ 的 $(s_i,s_j)$ 的数量
如果考场上仔细看一下题目名 qmras,倒过来就是 sarmq,就知道本题使用后缀数组解决
强烈推荐一篇后缀数组的博客(自为风月马前卒)
回到题目,知道了后缀数组,这道题就好办了
设 $dp[i]$ 代表(按排名)第 $i$ 个后缀对答案的贡献
按排名枚举每一个后缀,
$$dp[i]=dp[j]+\min(k,\,height[i])\cdot(i-j)$$
其中 $j$ 是在 $i$ 左边第一个满足 $height[j]<height[i]$ 的数
最终答案为所有 $dp$ 值的和
至于查找可以维护一个单调递增的栈
那为什么可以这样求 $dp[i]$ 呢?
首先,对于排名在 $j$ 到 $i-1$ 之间的后缀(即后缀 $j$ 及其右边、$i$ 左边的后缀,共 $i-j$ 个),与 $i$ 的最长公共前缀的长度为 $height[i]$,等价于存在那么多个长度分别为 $1 \sim height[i]$ 的子串与 $i$ 的前缀相等(每一个后缀的前缀,对应着一个唯一的子串);由于只统计长度不超过 $k$ 的子串,每个这样的后缀贡献 $\min(k, height[i])$
因为
$$LCP(sa[i],sa[j])=\min_{i<k\le j} height[k]$$
(此式中的 $i,j,k$ 只是下标,$k$ 不是题目中的长度上限)
对于后缀 $j$ 左边的后缀,和 $i$ 的最长公共前缀小于 $height[i]$,因此可以看作是递归处理 $j$
$LCP(sa[k],sa[j])$ 与 $LCP(sa[k],sa[i])$ $(k<j)$ 是一一对应的,即二者相等(可以自行思考,需要深刻理解后缀数组)
Code
#include <bits/stdc++.h>
using namespace std;
// Capacity limit for the input length and the modulus applied to the answer.
const int MAXN = 100000;
const int Mod = 998244353;

// Suffix-array work arrays. All of them are used 1-indexed.
int rk[MAXN + 10];      // rk[p]: rank of the suffix starting at position p
int sa[MAXN + 10];      // sa[r]: start position of the suffix with rank r
int id[MAXN + 10];      // scratch: second-key order, then the previous round's ranks
int buck[MAXN + 10];    // bucket counters for the counting sort in Rsort
int height[MAXN + 10];  // height[r]: common-prefix length of the suffixes ranked r-1 and r

// Stack of rank indices maintained by main; top is the number of stored entries.
int st[MAXN + 10];
int top = 0;

long long dp[MAXN + 10];  // dp[r]: per-rank value accumulated in main, kept reduced mod Mod
char str[MAXN + 10];      // input letters, stored at str[1..n]

// Defined after main.
void get_sa(int);
void Rsort(int, int);
void get_height(int);
/*
void Debug(int i){
printf("%d\n", (int)(lower_bound(st + 1, st + top + 1, height[i]) - st - 1));
return;
}
*/
/**
 * Reads n, k and n lowercase letters from qmras.in, builds the suffix array
 * and the height array, then walks the ranks with a stack of increasing
 * heights to fill dp[] and prints the sum of dp[1..n] modulo Mod to
 * qmras.out.
 *
 * Returns 0 on success and 1 when the header cannot be parsed or n does not
 * fit in the global buffers.
 */
int main(){
    freopen("qmras.in", "r", stdin);
    freopen("qmras.out", "w", stdout);

    int n = 0, k = 0;
    if (scanf("%d%d", &n, &k) != 2) return 1;

    // Index n + 1 must stay inside the buffers (it acts as the terminator),
    // so derive the limit from the buffer itself; the other work arrays are
    // declared with the same length as str.
    const int capacity = static_cast<int>(sizeof(str) / sizeof(str[0]));
    if (n < 0 || n > capacity - 2) return 1;

    for (int i = 1; i <= n; ++i){
        // Keep the result of getchar() in an int so EOF can be told apart
        // from a character; skip everything that is not a lowercase letter.
        int c = getchar();
        while (c != EOF && (c < 'a' || c > 'z')) c = getchar();
        if (c == EOF){
            // Input ended early: work with the letters actually read
            // instead of spinning forever.
            n = i - 1;
            break;
        }
        str[i] = static_cast<char>(c);
    }

    get_sa(n);
    get_height(n);

    for (int i = 1; i <= n; ++i){
        // Drop every rank whose height is not smaller than height[i]; what
        // remains on top is the nearest rank to the left with a strictly
        // smaller height.
        while (top && height[st[top]] >= height[i]) --top;
        const int len = top ? i - st[top] : i;
        // Widen before multiplying: len and the capped height can both be
        // around 5e4, and their product does not fit in a 32-bit int.
        const long long capped = std::min(height[i], k);
        dp[i] = (dp[i - len] + len * capped) % Mod;
        st[++top] = i;
    }

    long long ans = 0;
    for (int i = 1; i <= n; ++i)
        ans = (ans + dp[i]) % Mod;
    printf("%lld\n", ans);
    return 0;
}
/**
 * Stable counting sort that fills sa[1..n].
 *
 * The positions listed in id[1..n] are distributed by their key rk[position];
 * positions with equal keys keep the relative order they have in id[].
 *
 * @param n number of positions to sort
 * @param m largest key value that can occur in rk[1..n]
 */
void Rsort(int n, int m){
    // Clear bucket 0 as well: the prefix sums below start from it.
    for (int key = 0; key <= m; ++key) buck[key] = 0;
    for (int pos = 1; pos <= n; ++pos) ++buck[rk[pos]];
    // After this loop buck[key] is the last slot in sa[] reserved for key.
    for (int key = 1; key <= m; ++key) buck[key] += buck[key - 1];
    // Walk id[] backwards so that equal keys land in their original order.
    for (int i = n; i >= 1; --i){
        const int pos = id[i];
        sa[buck[rk[pos]]--] = pos;
    }
}
/**
 * Builds sa[1..n] and rk[1..n] for str[1..n] by prefix doubling, using Rsort
 * for every sorting pass.
 *
 * @param n length of the string stored in str[1..n]
 */
void get_sa(int n){
    int m = 127;  // key range of the first pass: plain character codes
    int p = 0;    // number of distinct ranks produced by the latest pass

    for (int i = 1; i <= n; ++i){
        rk[i] = static_cast<int>(str[i]);
        id[i] = i;
    }
    Rsort(n, m);

    // Each round sorts by (rank of first len chars, rank of next len chars)
    // and stops as soon as all n ranks are distinct.
    for (int len = 1; p < n; m = p, len <<= 1){
        // Order the positions by second key: the last len positions have no
        // second half and go first, the rest follow in the order given by sa.
        p = 0;
        for (int i = 1; i <= len; ++i) id[++p] = n - len + i;
        for (int i = 1; i <= n; ++i)
            if (sa[i] > len) id[++p] = sa[i] - len;

        Rsort(n, m);

        // id now holds the previous ranks; rk is rebuilt from them.
        std::swap(rk, id);
        rk[sa[1]] = p = 1;
        for (int i = 2; i <= n; ++i){
            const int cur = sa[i];
            const int prev = sa[i - 1];
            // The second-half lookup is only evaluated when the first halves
            // tie (short-circuit &&). NOTE(review): in that case both
            // suffixes are expected to be at least len long, so the index
            // stays <= n + 1, where the zero-initialised arrays read as 0.
            const bool same = id[cur] == id[prev] && id[cur + len] == id[prev + len];
            if (!same) ++p;
            rk[cur] = p;
        }
    }
}
void get_height(int n){
int k = 0;
for (register int i = 1; i <= n; ++i){
if (k) --k;
int j = sa[rk[i] - 1];
while (str[i + k] == str[j + k]) ++k;
height[rk[i]] = k;
}
}