字符串是最直观的,然而真正的字符串算法同阶段而言没一个好理解的。
——学SA第一夜有感
SA的简单介绍和应用
SA的大部分学习内容都是来源于OI Wiki,尽管理解的内容会比专业叙述更易于理解一些,然而还是OI Wiki讲的详细和科学。因此贴一个原链接:后缀数组简介 - OI Wiki
一.简单且清晰地理解SA
这个学习报告的最重要且是唯一的意义,在于让我迅速记住自己到底 n 个晚上到底都记住了些什么…
求后缀数组
几个基本概念简单说一句:在一个字符串当中,后缀是一个字符串到末尾的子串,前缀是一个字符串从开头起的子串。
基础的后缀数组一共只有两部分:
s
a
(
i
)
sa(i)
sa(i) 表示第
i
i
i 小的后缀的编号(显然是其左端点,从1开始),
r
k
(
i
)
rk(i)
rk(i) 表示编号为
i
i
i 的后缀从小到大的排名。
现在的问题就是怎么求出这两个数组,由于SA实质上是一个排名的数组,所以求SA的过程称为后缀排名。
一种比较常用也是相对比较好理解的方法是倍增法。首先设
r
k
w
(
i
)
rk_w(i)
rkw(i) 表示编号为
i
i
i ,长度为
k
k
k 的子串的排名(实际上,代码实现中是直接把
r
k
(
i
)
rk(i)
rk(i) 进行一个类似滚动的操作来完成)。
假设已知
r
k
k
(
i
)
rk_k(i)
rkk(i) ,要想过渡到
r
k
2
k
(
i
)
rk_{2k}(i)
rk2k(i) ,比较容易想到的是这个是由两段子串组成的;假设想要进一步排序,按照字典序来说应该先比较前半段,然后比较后半段,这就区分出了第一和第二关键字。这样就可以完成一次排序,在实际操作中就是在每一次排序后滚动一下
r
k
(
i
)
rk(i)
rk(i) ,然后再进行第二次排序。如果
i
+
2
k
−
1
i+2k-1
i+2k−1 已经超出
n
n
n ,那其实也没有什么问题,毕竟后面多出的部分其实全都是0.
这样一来,倍增算法的复杂度是
O
(
(
排
序
+
n
)
l
o
g
n
)
O((排序+n)logn)
O((排序+n)logn) ,所以应该考虑用一些更快捷的排序方法降低复杂度,这里采取的是计数排序。
计数排序跟桶排序有点像,我感觉这俩的区别基本在于桶排序是分块一样地处理,先装进各自的桶排序,然后合并;计数排序是直接塞进以值为下标的数组然后直接合并到一起,所以两者确实不是很一样。
然而我感觉我刚学OI的时候学的桶排序跟计数排序除了没有前缀和合并真的就是一回事…
在计数排序当中,两个某一级关键字相同的元素经过这一级排序后相对关系不变,因此应该按关键字从低到高反向排序(不是很好理解,建议举个例子自行尝试一下)。
此时的代码如下:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 2e6 + 1;
char s[N];
int k,n,m,sa[N],temp[N],num[N],buc[N],rk[N];
bool check(int x,int y,int z){
return buc[x] == buc[y] && buc[x + z] == buc[y + z];
}
void sort(){//计数排序
int i;
memset(num,0,sizeof(num));
for(i = 1;i <= n;i++) temp[i] = sa[i];
for(i = 1;i <= n;i++) ++num[rk[temp[i] + k]];
for(i = 1;i <= m;i++) num[i] += num[i - 1];
for(i = n;i >= 1;i--) sa[num[rk[temp[i] + k]]--] = temp[i];
//先按rk[i+k]排序
memset(num,0,sizeof(num));
for(i = 1;i <= n;i++) temp[i] = sa[i];
for(i = 1;i <= n;i++) ++num[rk[temp[i]]];
for(i = 1;i <= m;i++) num[i] += num[i - 1];
for(i = n;i >= 1;i--) sa[num[rk[temp[i]]]--] = temp[i];
//再按rk[i]排序
}
int main(){
int i,j;
scanf("%s",s + 1);
n = strlen(s + 1);
m = max(n,300);//m代表值域大小,实际上直接=300通常也可以
for(i = 1; i <= n; ++i) sa[i] = i,rk[i] = s[i];
for(k = 1;k < n;k <<= 1){
sort();
memcpy(buc,rk,sizeof(rk));//“滚动”
for(i = 1,j = 0;i <= n;i++){
if(check(sa[i],sa[i - 1],k)) rk[sa[i]] = j;//去重
else rk[sa[i]] = ++j;
}
}
for(i = 1;i <= n;i++) printf("%d ",sa[i]);
return 0;
}
在理解后缀排序的时候,不要忘记 s a ( i ) sa(i) sa(i) 是待排序数组, r k ( i ) rk(i) rk(i) 则是用于维护它的一个的反函数,否则很容易犯 r k [ s a [ i ] ] rk[sa[i]] rk[sa[i]] 和 s a [ r k [ i ] ] sa[rk[i]] sa[rk[i]] 写串之类的问题。
采取计数排序后,算法的复杂度会达到 O ( n l o g n ) O(nlogn) O(nlogn) 的水平,从理论上来说已经足够快了,但是此时的算法常数会非常大,为了让程序更快一些,有必要采取一些优化。
吐槽一句,洛谷的SA模板甚至用快排都能轻松通过,真的有一点水233
这段代码常数之所以大,很大程度上是由于计数排序的常数大,所以还是要继续优化排序。
首先是这个值域的问题,初始的值域是
m
a
x
(
n
,
300
)
max(n,300)
max(n,300) ,实际上很多时候用不到这么大。每一次更新了
r
k
k
(
i
)
rk_k(i)
rkk(i) 之后,其实下一次排序用到的值域就是这么大,所以可以一定程度上缩小值域。
其次,排序不一定要完全排完,假设某一轮排序后
r
k
k
(
i
)
rk_k(i)
rkk(i) 不存在重复,那么接下来扩展字符串的时候也一定不存在重复,且大小关系也一定不变,这个时候可以直接退出排序。
最后是关键字的问题,在排第二维的时候有很多字符串实际已经超出范围了(也就是
r
k
k
(
i
)
)
rk_k(i))
rkk(i)) 最小),这些可以直接放到最前面,剩下的按照上一次排出的排名(也就是调用上一轮排序后已经有序的
s
a
(
i
)
sa(i)
sa(i) )直接接到后面就可以了。
另外就是一些涉及到C++或者计算机原理的东西,直接写注释里面了。
代码如下(洛谷746ms):
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 2e6 + 1;
char s[N];
int k,n,m,sa[N],temp[N],num[N],buc[N],rk[N],px[N];
#define check(x,y,z) buc[x] == buc[y] && buc[x + z] == buc[y + z]
void sort(){
int i,j;
for(i = n,j = 0;i > n - k;i--) temp[++j] = i;
for(i = 1;i <= n;i++) if(sa[i] > k) temp[++j] = sa[i] - k;
//简化掉第二关键字排序
memset(num,0,sizeof(num));
for(i = 1;i <= n;i++) ++num[px[i] = rk[temp[i]]];
//把[rk[temp[i]]存下来可以加速
for(i = 1;i <= m;i++) num[i] += num[i - 1];
for(i = n;i >= 1;i--) sa[num[px[i]]--] = temp[i];
}
int main(){
int i,j;
scanf("%s",s + 1);
n = strlen(s + 1);
m = 300;
for(i = 1;i <= n;++i) ++num[rk[i] = s[i]];
for(i = 1;i <= m;++i) num[i] += num[i - 1];
for(i = n;i >= 1;--i) sa[num[rk[i]]--] = i;
//需要预先对sa[i]进行一次排序,否则第一次第二关键字排序是错的
for(k = 1;k < n;k <<= 1){
sort();
memcpy(buc,rk,sizeof(rk));
for(i = 1,j = 0;i <= n;i++){
if(check(sa[i],sa[i - 1],k)) rk[sa[i]] = j;//PS
else rk[sa[i]] = ++j;
}
m = j;//更新值域
if(j == n) break;//提前结束排序
}
for(i = 1;i <= n;i++) printf("%d ",sa[i]);
return 0;
}
PS:按照OI Wiki的解释,这个check写成函数比直接把表达式写到括号里面更快。然而我用#define比用函数还快,#define又是文本替换,所以我也搞不懂这个优化怎么回事…
求高度数组
如果说前面的部分只是对 s a ( i ) sa(i) sa(i) r k ( i ) rk(i) rk(i) 性质的简单运用就已经有点混乱了,那这部分才是真的让人完全找不到北。
单纯的后缀数组其实作用并不多,因为很多问题关心的是子串而非后缀。不过实际上任意一个子串都是一个后缀的前缀,所以可以在后缀数组当中研究前缀的问题来扩展SA的作用。
记
l
c
p
(
i
,
j
)
lcp(i,j)
lcp(i,j) 表示
s
a
(
i
)
sa(i)
sa(i),
s
a
(
j
)
sa(j)
sa(j) 的最长公共前缀,定义高度数组
h
t
(
i
)
=
l
c
p
(
s
a
(
i
)
,
s
a
(
i
−
1
)
)
ht(i)=lcp(sa(i),sa(i-1))
ht(i)=lcp(sa(i),sa(i−1)) ,即第
i
i
i 名的后缀与它前一名的后缀的最长公共前缀。
h
t
(
i
)
ht(i)
ht(i) 的一个关键意义在于它具有一定的传递性且相对比较容易计算。
一个关键性质是:
l
c
p
(
s
a
(
i
)
,
s
a
(
j
)
)
=
m
i
n
{
h
t
(
i
+
1..
j
)
}
lcp(sa(i),sa(j))=min\{ht(i+1..j)\}
lcp(sa(i),sa(j))=min{ht(i+1..j)}
因此只要求出
h
t
(
i
)
ht(i)
ht(i) ,就能把求
l
c
p
lcp
lcp 变成一个RMQ问题,也就能比较容易地解决一系列
l
c
p
lcp
lcp 的问题。现在的问题就变成了怎么求出这个数组。
在讨论这个数组的求法之前,很有必要重新强调一下 s a ( i ) sa(i) sa(i) r k ( i ) rk(i) rk(i) 这两个数组的意义。 s a ( i ) sa(i) sa(i) 数组下标是排名,因此其下标的变化是排名的变化; r k ( i ) rk(i) rk(i) 数组下标是编号,因此其下标的变化是编号(即字符串中真正的位置)的变化。这两者下标和权值的变化恰好相反。
如果不搞清楚这俩的关系,接下来的部分会非常非常难以阅读和理解,至少看懂高度数组的求法花了我约半个小时。
对于
h
t
(
i
)
ht(i)
ht(i) 数组有一个引理:
h
t
(
r
k
(
i
)
)
≥
h
t
(
r
k
(
i
−
1
)
)
−
1
ht(rk(i)) \ge ht(rk(i-1))-1
ht(rk(i))≥ht(rk(i−1))−1
证明:显然只有
h
t
(
r
k
(
i
−
1
)
)
>
1
ht(rk(i-1)) > 1
ht(rk(i−1))>1 的部分值得讨论(因为
h
t
ht
ht 具有非负性)。
假设后缀
i
−
1
i-1
i−1 为
a
A
D
aAD
aAD ,后缀
i
i
i 为
A
D
AD
AD (其中
A
A
A 是一段长度为
h
t
(
r
k
(
i
−
1
)
)
−
1
ht(rk(i-1))-1
ht(rk(i−1))−1 的字符串),那么
s
a
(
r
k
(
i
−
1
)
−
1
)
sa(rk(i-1)-1)
sa(rk(i−1)−1) (编号为
i
−
1
i-1
i−1 的后缀的前一名的后缀)如果设为
a
A
B
aAB
aAB ,
s
a
(
r
k
(
i
−
1
)
−
1
)
+
1
sa(rk(i-1)-1)+1
sa(rk(i−1)−1)+1 为
A
B
AB
AB.
同理,
s
a
(
r
k
(
i
)
−
1
)
sa(rk(i)-1)
sa(rk(i)−1) 应该为
A
B
/
A
C
AB/AC
AB/AC,此时
l
c
p
(
s
a
(
r
k
(
i
)
)
,
s
a
(
r
k
(
i
)
−
1
)
)
lcp(sa(rk(i)),sa(rk(i)-1))
lcp(sa(rk(i)),sa(rk(i)−1)) ,即
l
c
p
(
i
,
s
a
(
r
k
(
i
)
−
1
)
)
=
A
X
lcp(i,sa(rk(i)-1)) = AX
lcp(i,sa(rk(i)−1))=AX ,
X
X
X 也可以为空。证毕。
通过这个引理,为了求
h
t
(
i
)
ht(i)
ht(i) ,只要从第
s
a
(
r
k
(
i
)
−
1
)
sa(rk(i)-1)
sa(rk(i)−1) 位开始与从第 $i $ 位开始的子串逐位地比对即可(如果比对成功了,不要忘记引理的-1,也就是需要前移这个比对成功的长度)。考虑到这个比对的指针最多只会增减
n
n
n 次,所以是线性的。
代码如下:
void Getht(){
int i,j;
for(i = 1,j = 0;i <= n;i++){
if(j) --j;
while(s[i + j] == s[sa[rk[i] - 1] + j]) ++j;
ht[rk[i]] = j;
}
ht[1] = 0;//用这个算法算出的ht[1]可能不是0,得手动维护一下
}