后缀数组做法
写给不了解后缀数组的:
后缀数组的做法一般都包括三个数组:
s
a
sa
sa:所有后缀中字典序第
i
i
i 大的是从位置
s
a
[
i
]
sa[i]
sa[i] 开始的后缀;
r
a
n
k
rank
rank:位置
i
i
i 开始的后缀在所有后缀中字典序排第
r
a
n
k
[
i
]
rank[i]
rank[i];
l
c
p
lcp
lcp:高度数组,字典序第
i
i
i 大得缀与字典序第
i
+
1
i+1
i+1 大的后缀的最长公共前缀(longest common prefix)。另外通过建立
l
c
p
lcp
lcp 上的 st表,可以求得任意两个后缀
x
x
x 和
y
y
y 的最长公共前缀:
q
u
e
r
y
_
m
i
n
(
r
a
n
k
[
x
]
,
r
a
n
k
[
y
]
−
1
)
,
这里假设
r
a
n
k
[
x
]
<
r
a
n
k
[
y
]
query\_min(rank[x],rank[y]-1), 这里假设rank[x]<rank[y]
query_min(rank[x],rank[y]−1),这里假设rank[x]<rank[y] 。
大体思路:对于当前字符串 B i B_i Bi 的每个位置 j ( 1 ≤ j ≤ m ) j \ (1 \le j \le m) j (1≤j≤m) , 找到从 j j j 开始的最长的子串,且这个子串在 A A A 中出现过。也就是找到一个最长的长度 l e n len len,满足 B i B_i Bi 的子串 [ j , j + l e n − 1 ] [j,j+len-1] [j,j+len−1] 同时也是 A A A 的一个子串;然后再就是找以位置 j j j 为左端点,长度在 l e n len len 之内所有的区间权值和的最大值,也就是 max ( ∑ k = j l v k ) , ( j ≤ l ≤ j + l e n − 1 ) \max( \sum_{k=j}^{l} v_k ), (j \le l \le j+len-1) max(∑k=jlvk),(j≤l≤j+len−1) ,这部分可以对数组 v v v 的前缀和数组建 s t st st 表,然后查询区间 [ j , j + l e n − 1 ] [j,j+len-1] [j,j+len−1] 的最大值。
关于如何找到 B i B_i Bi 以位置 j ( 1 ≤ j ≤ m ) j \ (1 \le j \le m) j (1≤j≤m) 开头,在 A A A 中出现过的最长子串:
把 A A A 和所有 B i B_i Bi 连在一起,中间用没出现过的字符(例如 ‘$’)分隔,得到的字符串记为 S S S。然后对 S S S 跑后缀数组。在拼接的过程中,维护每个 B i B_i Bi 在 S S S 中的起始下标,和第 i i i 个字符对应于原来哪一个串,后面要用。
跑出来后缀数组后,按字典序遍历所有后缀。借助前面维护的信息,我们可以知道当前后缀对应哪个 B i B_i Bi 或者说对应 A A A; 如果当前后缀是对应某个 B i B_i Bi 的,就找到离它最近的,属于 A A A 串的后缀,求它们之间的 l c p lcp lcp ,这个 l c p lcp lcp 就是我们前面要求的那个 l e n len len。
比如题目样例一:(左边的三列数字分别代表:字典序大小,sa 的值,lcp 的值)
字典序第 18 18 18 大和第 22 22 22 大的后缀都是原属于 A A A 的后缀,第 19 19 19 大的后缀对应于 B 3 B_3 B3 从下标 1 开始的后缀,也就是它本身。 min 18 ≤ i < 19 ( l c p [ i ] ) = 4 , min 19 ≤ i < 22 ( l c p [ i ] ) = 0 \underset{18\le i<19}\min(lcp[i]) = 4,\underset{19\le i<22}\min(lcp[i]) = 0 18≤i<19min(lcp[i])=4,19≤i<22min(lcp[i])=0 ,那么 B 3 B_3 B3 从 1 开始,长度在 4 4 4 以内的子串都在 A A A 中出现过,求得这个最长的公共子串长度后,用 st 表求一个权重前缀和的最大值,更新 B 3 B_3 B3 的答案即可。
有两个需要注意的点:
1、本来求 lcp 的部分我是用先从前到后和从后到前循环一遍,记录每个位置左边第一个和右边第一个属于
A
A
A 的后缀,然后 st 表查询区间 lcp 最小值,但是超时了。后来发现在循环的时候直接维护最小值就行了,不需要构建 st 表。
2、所有
B
i
B_i
Bi 之间可以用 ‘$’ 分隔,但是
A
A
A 和
B
1
B_1
B1 之间最好再换个,不然下面这种数据,跑出来的 lcp 可能处理起来有点麻烦.
最后勉强 700+ms 跑过:
#include<iostream>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
const int INF = 1e9+7;
const long long INFq = 1e18+7;
const long long mode = 998244353;
const int MAX_N = 1300100;
char s[MAX_N];
///----- SA-IS template -----
int sa[MAX_N], Rank[MAX_N], lcp[MAX_N];
int str[MAX_N<<1], Type[MAX_N<<1], p[MAX_N], cnt[MAX_N], cur[MAX_N];
#define pushS(x) sa[cur[str[x]]--] = x
#define pushL(x) sa[cur[str[x]]++] = x
#define inducedSort(v) fill_n(sa, n, -1); fill_n(cnt, m, 0); \
for (int i = 0; i < n; i++) cnt[str[i]]++; \
for (int i = 1; i < m; i++) cnt[i] += cnt[i-1]; \
for (int i = 0; i < m; i++) cur[i] = cnt[i]-1; \
for (int i = n1-1; ~i; i--) pushS(v[i]); \
for (int i = 1; i < m; i++) cur[i] = cnt[i-1]; \
for (int i = 0; i < n; i++) if (sa[i] > 0 && Type[sa[i]-1]) pushL(sa[i]-1); \
for (int i = 0; i < m; i++) cur[i] = cnt[i]-1; \
for (int i = n-1; ~i; i--) if (sa[i] > 0 && !Type[sa[i]-1]) pushS(sa[i]-1)
void sais(int n, int m, int *str, int *Type, int *p) {
int n1 = Type[n-1] = 0, ch = Rank[0] = -1, *s1 = str+n;
for (int i = n-2; ~i; i--) Type[i] = str[i] == str[i+1] ? Type[i+1] : str[i] > str[i+1];
for (int i = 1; i < n; i++) Rank[i] = Type[i-1] && !Type[i] ? (p[n1] = i, n1++) : -1;
inducedSort(p);
for (int i = 0, x, y; i < n; i++) if (~(x = Rank[sa[i]])) {
if (ch < 1 || p[x+1] - p[x] != p[y+1] - p[y]) ch++;
else for (int j = p[x], k = p[y]; j <= p[x+1]; j++, k++)
if ((str[j]<<1|Type[j]) != (str[k]<<1|Type[k])) {ch++; break;}
s1[y = x] = ch;
}
if (ch+1 < n1) sais(n1, ch+1, s1, Type+n, p+n1);
else for (int i = 0; i < n1; i++) sa[s1[i]] = i;
for (int i = 0; i < n1; i++) s1[i] = p[sa[i]];
inducedSort(s1);
}
int mapCharToInt(int n) {
int m = *max_element(s, s+n);
fill_n(Rank, m+1, 0);
for (int i = 0; i < n; i++) Rank[s[i]] = 1;
for (int i = 0; i < m; i++) Rank[i+1] += Rank[i];
for (int i = 0; i < n; i++) str[i] = Rank[s[i]] - 1;
return Rank[m];
}
void SuffixArray(int n) {
// s[n] 一定要比 s 中所有字符 ascii 值小, s[n+1] 倒无所谓
s[n] = '!'; s[n+1]='\0';
int m = mapCharToInt(++n);
sais(n, m, str, Type, p);
for (int i = 0; i < n; i++) Rank[sa[i]] = i;
for (int i = 0, h = lcp[0] = 0; i < n-1; i++) {
int j = sa[Rank[i]-1];
while (i+h < n && j+h < n && s[i+h] == s[j+h]) h++;
if (lcp[Rank[i]-1] = h) h--;
}
s[n]='\0';
}
///----- End of SA-IS -----
long long st2[100010][20];
int lg[100010];
long long pref[100010];
void construct_st2(int n) {
for(int i=1;i<=n;i++)st2[i][0] = pref[i];
for(int k=1,len=2; len<=n; len*=2,k++) {
for(int i=1;i+len-1<=n;i++) {
st2[i][k] = max( st2[i][k-1], st2[i+len/2][k-1] );
}
}
}
inline long long query2(int x,int y) {
int k = lg[y-x+1];
return max( st2[x][k], st2[y-(1<<k)+1][k] );
}
int v[100010];
int start_pos[100010];
long long ans[100010];
int Map[MAX_N];
int main() {
ios::sync_with_stdio(false);
lg[1] = 0;
for(int i=2;i<=100000;i++) lg[i] = lg[i >> 1] + 1;
int n,m,k;
cin >> n >> m >> k;
cin >> s;
for(int i=1;i<=m;i++) cin >> v[i];
int tot_len = n-1;
for(int i=0;i<n;i++) Map[i] = 0; // Map 用来映射 s[i] 对应于原来那个串,0 就是 A;
for(int i=1;i<=k;i++) {
++ tot_len;
s[tot_len] = '$'; Map[tot_len] = -1; // -1 代表是分隔符;
start_pos[i] = tot_len+1; // 记录开始位置
cin >> ( s + tot_len + 1 );
for(int j=tot_len+1; j<=tot_len+m; j++) Map[j] = i; // 代表 s[j] 原属于 B_i
tot_len += m;
}
s[n] = '#'; // A 和 B_1 之间用 '#' 而非 '$'
++ tot_len;
SuffixArray(tot_len ); // 板子传入的参数是字符串的长度,下标从 0 开始, tot_len 是 '\0' 的位置
// s[tot_len] = '\0';
// cout << "s = " << s << '\n';
// for(int i=0;i<=tot_len;i++) {
// printf("%3d %3d %3d %s\n",i,sa[i],lcp[i],s+sa[i]);
// }
// cout << '\n';
// 构建前缀和
pref[0] = 0;
for(int i=1;i<=m;i++) pref[i] = pref[i-1] + v[i];
// 前缀和的区间最大值; 为啥是 st2?因为原本有个(多余的) st 用来求 lcp, 但超时了
construct_st2(m);
int Min = 0;
for(int i=1;i<=tot_len;i++) { // 从左到右遍历一边, 用每个后缀左边第一个属于 A 的后缀更新答案
int j = Map[ sa[i] ]; // sa[i] 代表字典序第 i 大的后缀在原串的起始位置,再用 Map 映射到原来对应的串
if( j == 0 ) {
Min = lcp[i]; // 是 A 的后缀,则重置 Min
}
else {
if( j > 0 && Min > 0 ) { // 对应 B_j 的某个后缀
int index = sa[i] - start_pos[j] + 1; // index 是对应的 B_j 的那个后缀的起始下标
long long Max = query2( index, index + Min - 1 ); // 查询区间 pref 最大值
ans[j] = max( ans[j] , Max - pref[index-1] ); // 更新答案
}
Min = min( Min, lcp[i] );
}
}
Min = 0;
for(int i=tot_len;i>0;i--) { // 从右到左遍历,用每个后缀右边第一个属于 A 的后缀更新答案,几乎一样的
int j = Map[ sa[i] ];
if( j == 0 ) {
Min = lcp[i-1];
}
else {
if( j > 0 && Min > 0 ) {
int index = sa[i] - start_pos[j] + 1;
long long Max = query2( index, index + Min - 1 );
ans[j] = max( ans[j] , Max - pref[index-1] );
}
Min = min( Min, lcp[i-1] );
}
}
for(int i=1;i<=k;i++) cout << ans[i] << '\n';
}
后缀自动机做法
和后缀数组的思路是一样的,不过这里对于字符串 B i B_i Bi 的每个位置 j ( 1 ≤ j ≤ m ) j \ (1 \le j \le m) j (1≤j≤m) ,是找以 j j j 结尾的最长的在 A A A 中出现过的子串,后面查询的也是 [ j − l e n , j − 1 ] [j-len,\ j-1] [j−len, j−1] 之间前缀和的最小值。
怎么找:
对 A A A 建立后缀自动机后,记当前的 B i B_i Bi 为 T T T (这样我能少打一个下标qwq),在 A A A 的自动机上跑匹配,假设 T T T 串第 i − 1 i-1 i−1 的位置在 A A A 的自动机上匹配的最大子串长度为 m a x _ l e n max\_len max_len,对应自动机上的节点为 l a s t _ p o s last\_pos last_pos,那么以 T [ i ] T[i] T[i] 结尾的串肯定是某个以 T [ i − 1 ] T[i-1] T[i−1] 结尾的串后面加上字符 T [ i ] T[i] T[i],我们就从 l a s t _ p o s last\_pos last_pos 开始在 p a r e n t parent parent 树中向上转移,直到遇到第一个存在字符 T [ i ] T[i] T[i] 的出边的节点位置,这个过程中记录 m a x _ l e n max\_len max_len ,最后 +1 就是 i i i 的答案。
嗯,自己写的自己都看不懂写的什么东西。 还是看图吧
假设
A
A
A 为
b
c
d
a
b
c
bcdabc
bcdabc,
T
T
T 为
a
b
c
d
abcd
abcd,开始
l
a
s
t
_
p
o
s
last\_pos
last_pos 设为
1
1
1 ,代表根节点,
m
a
x
_
l
e
n
=
0
max\_len = 0
max_len=0,因为根节点对应的子串为空串,
A
A
A 的自动机长这样子:(每个节点块最后一行{}里的是该节点的
e
n
d
p
o
s
endpos
endpos 集——节点代表的子串在原串中的结束位置;黑色的边是 parent 的边,蓝色带箭头的是自动机的转移边,旁边的字母是对应的出边的类型;黑边上也有字母是因为 parent 的边和自动机的边重了;len 代表当前节点所代表的子串的最大长度)
每个节点对应的子串:
首先是 a
,正好
1
1
1 号节点有 a
的出边,走到 5 号节点,
m
a
x
_
l
e
n
max\_len
max_len++,最大长度为
1
1
1
然后是 b
,
5
5
5 号节点有 b
的出边,走到
6
6
6 号节点,
m
a
x
_
l
e
n
max\_len
max_len++,最大长度为
2
2
2
然后是 c
,
6
6
6 号节点有 c
的出边,走到
7
7
7 号节点,
m
a
x
_
l
e
n
max\_len
max_len++,最大长度为
3
3
3
然后是 d
,
7
7
7 号节点没有 d
的出边,沿
p
a
r
e
n
t
parent
parent 数向上走,走到
3
3
3 号节点有 d
的出边,走到
4
4
4 号节点,
m
a
x
_
l
e
n
=
l
e
n
[
3
]
+
1
=
3
max\_len = len[3]+1 = 3
max_len=len[3]+1=3
#include<iostream>
#include<cstdio>
#include<vector>
#include<cstring>
using namespace std;
const int MAX_N = 100010;
int par[MAX_N<<1], sam[MAX_N<<1][26],len[MAX_N<<1];
int last,tot;
void sam_extend(int ch) {
int p = last;
tot++;
int np = last = tot;
len[np] = len[p] + 1;
while( p>0 && sam[p][ch]==0 ){
sam[p][ch] = np;
p = par[p];
}
if( p==0 ){
par[np] = 1;
}
else{
int q = sam[p][ch];
if( len[q] == len[p]+1 )par[np] = q;
else{
tot++;
int nq = tot;
len[nq] = len[p]+1;
par[nq] = par[q];
for(int i=0;i<26;i++)sam[nq][i] = sam[q][i];
par[np] = par[q] = nq;
while( p>0 && sam[p][ch]==q ){
sam[p][ch] = nq;
p = par[p];
}
}
}
}
int last_pos, max_len;
void Go(int ch) {
int p = last_pos;
while( p > 0 && sam[p][ch] == 0 ) {
p = par[p];
max_len = len[p];
}
if( p == 0 ) {
// 如果 1 号根节点都没有 ch 的出边,说明字符 ch 在字符串中不存在
last_pos = 1;
}
else {
int q = sam[p][ch]; // 沿着出边走出去
++ max_len; // 就是当前的最大长度
last_pos = q;
}
}
long long pref[100010];
long long st[MAX_N][20];
int lg[MAX_N];
void construct_st(int n) {
for(int i=0;i<=n;i++)st[i][0] = pref[i];
for(int k=1,len=2; len<=n; len*=2,k++) {
for(int i=0;i+len-1<=n;i++) {
st[i][k] = min( st[i][k-1], st[i+len/2][k-1] );
}
}
}
long long query(int left,int right) {
int k = lg[right-left+1];
return min( st[left][k], st[right-(1<<k)+1][k] );
}
char s[100010], t[100010];
int v[100010];
long long ans[100010];
void print_sam(){
vector<int>edge[20];
for(int i=2;i<=tot;i++)edge[par[i]].push_back(i);
for(int i=1;i<=tot;i++) {
printf("child %d :",i); for(int u : edge[i])printf(" %d",u); printf("\n");
}
for(int i=1;i<=tot;i++) {
printf("sam %d :\n",i);
for(int j=0;j<26;j++) {
if( sam[i][j] > 0 ) {
printf(" %c -> %d\n",'a'+j,sam[i][j]);
}
}
}
}
int main() {
// cin.tie(nullptr) -> sync_with_stdio(false);
lg[1] = 0;
for(int i=2;i<=100000;i++) lg[i] = lg[i >> 1] + 1;
int n,m,k;
cin >> n >> m >> k;
cin >> (s+1);
for(int i=1;i<=m;i++) cin >> v[i];
last = tot = 1;
for(int i=1;i<=n;i++) sam_extend(s[i] - 'a');
pref[0] = 0;
for(int i=1;i<=m;i++) pref[i] = pref[i-1] + v[i];
construct_st(m);
for(int j=1; j<=k; j++) {
cin >> (t+1);
last_pos = 1;
max_len = 0;
for(int i=1;i<=m;i++) {
Go(t[i] - 'a'); // 在 parent 树上沿着 last_pos 向上找到第一个有出边 t[i] 的节点
if( max_len > 0 )
ans[j] = max( ans[j], pref[i] - query(i-max_len,i) );
}
}
for(int i=1;i<=k;i++) cout << ans[i] << '\n';
}
如果没学过后缀自动机:
沿着 p a r e n t parent parent 树向下走,相当于在左边添加字符,而越长的子串在原串中的的出现位置相对更少。为什么说要 “从 l a s t _ p o s last\_pos last_pos 开始在 p a r e n t parent parent 树中向上转移,直到遇到第一个存在字符 T [ i ] T[i] T[i] 的出边的节点位置”,因为向上走,相当于不断去掉左边的字符,越短的子串在原串的出现的位置相对更多,更“可能”会遇到一个后面跟着一个字符 T [ i ] T[i] T[i] 的位置。
而沿着 s a m sam sam 的出边转移,相当于在子串的后面添加字符。