因为刚学后缀自动机,似懂非懂,所以这么琐碎的东西也来水博客。
题意
1.相同子串算一个,求一个字符串中字典序第k大的子串。
2.多个相同子串算多个,求一个字符串中字典序第k大的子串。
思路
SAM基本用法之一,建出SAM然后DP。
首先我们令 f [ u ] f[u] f[u]表示从SAM上的 u u u节点出发能走出多少不同子串(包括空串)。
那么我们可以考虑一下一个节点 u u u表示的意义。
- u u u包含了一些长度连续的子串,这些子串在一些 e n d p o s endpos endpos出现。
然后再考虑一下一条边 c h [ u ] [ x ] = v ch[u][x] = v ch[u][x]=v的意义。
- u u u的有一些 e n d p o s endpos endpos位置后面是 x x x字符(但不一定所有 e n d p o s endpos endpos后面都是 x x x)。
- ∀ s t ∈ u , s t + x ∈ v \forall st \in u, st+x \in v ∀st∈u,st+x∈v,但是 v v v可能有其他入边,即 v v v还可能存在其他子串。
- l e n ( v ) ≥ l e n ( u ) + 1 len(v) \ge len(u)+1 len(v)≥len(u)+1, s i z e _ o f _ e n d p o s ( v ) = ∑ r i ∈ e n d p o s ( u ) [ s [ r i + 1 ] = x ] size\_of\_endpos(v)=\sum_{r_i \in endpos(u)} [s[r_i+1]=x] size_of_endpos(v)=∑ri∈endpos(u)[s[ri+1]=x],这两条可以由第2条推得。
然后我们看题目的两种问法,分别求解:
-
f
[
u
]
=
∑
v
=
c
h
[
u
]
[
i
]
f
[
v
]
+
1
f[u]=\sum_{v=ch[u][i]} f[v]+1
f[u]=∑v=ch[u][i]f[v]+1,意思是:
(随便YY)u u u出发的不同子串包含一个空串和加上至少一个字符之后能得到的字符串,递推求解。 - f [ u ] = ∑ v = c h [ u ] [ i ] f [ v ] + s i z e _ o f _ e n d p o s ( u ) f[u]=\sum_{v=ch[u][i]} f[v]+size\_of\_endpos(u) f[u]=∑v=ch[u][i]f[v]+size_of_endpos(u),因为多个不同的串要算不同个数,所以空串不能只算一个了,而是要在每个位置都算一个,但是儿子们的贡献并不用修改,因为 f [ v ] f[v] f[v]中多个不同子串也算了多次。
因为上面边的第3个性质,所以我们可以按 l e n ( u ) len(u) len(u)从大到小扫每个节点,保证扫到 u u u时所有 c h [ u ] [ i ] ch[u][i] ch[u][i]都已经更新过。
DP完了之后找k大就很简单了, O ( l e n ( a n s ) ∗ 26 ) O(len(ans)*26) O(len(ans)∗26)。
代码
#include<bits/stdc++.h>
using namespace std;
const int N = 5e5+10, PN = N<<1, S = 26;
int siz[PN];
namespace SAM
{
int fa[PN], ch[PN][S], len[PN], cnt, last;
inline void reset(){
last = cnt = 1;
}
inline void copy(int x, int y){
for (int i = 0; i < S; ++ i)
ch[y][i] = ch[x][i];
fa[y] = fa[x];
}
inline void extend(int x){
x -= 'a';
int p = last, np = ++cnt;
last = np;
len[np] = len[p]+1; siz[np] = 1;
while (p && !ch[p][x]) ch[p][x] = np, p = fa[p];
if (!p) fa[np] = 1;
else{
int q = ch[p][x];
if (len[q] == len[p]+1) fa[np] = q;
else{
int nq = ++cnt;
copy(q, nq);
len[nq] = len[p]+1; siz[nq] = 0;
fa[q] = fa[np] = nq;
while (p && ch[p][x] == q) ch[p][x] = nq, p = fa[p];
}
}
}
}
using namespace SAM;
int n, t, k, rk[PN], bin[N], f[PN];
char s[N];
void get_rk()
{
for (int i = 1; i <= cnt; ++ i) ++bin[len[i]];
for (int i = 1; i <= n; ++ i) bin[i] += bin[i-1];
for (int i = 1; i <= cnt; ++ i) rk[bin[len[i]]--] = i;
}
void solve1()
{
for (int i = cnt; i >= 1; -- i){
int x = rk[i];
siz[fa[x]] += siz[x];
}
for (int i = cnt; i >= 1; -- i){
int x = rk[i];
f[x] = siz[x];
for (int j = 0; j < S; ++ j)
if (ch[x][j])
f[x] += f[ch[x][j]];
}
}
void solve2()
{
for (int i = 1; i <= cnt; ++ i) siz[i] = 1;
for (int i = cnt; i >= 1; -- i){
int x = rk[i];
f[x] = 1;
for (int j = 0; j < S; ++ j)
if (ch[x][j])
f[x] += f[ch[x][j]];
}
}
void print()
{
k += siz[1];
if (f[1] < k) return (void)(puts("-1"));
int u = 1;
while (k){
if (k <= siz[u]) break;
k -= siz[u];
for (int i = 0; i < S; ++ i)
if (ch[u][i]){
if (f[ch[u][i]] < k) k -= f[ch[u][i]];
else{
putchar(i+'a');
u = ch[u][i];
break;
}
}
}
puts("");
}
int main()
{
scanf("%s%d%d", s, &t, &k);
n = strlen(s);
reset();
for (int i = 0; i < n; ++ i)
extend(s[i]);
get_rk();
if (t) solve1();
else solve2();
print();
return 0;
}