看完题解:看不懂;
看 AC 代码:!!??这样的吗?就这么简单?但为啥不会超时啊?
满足题目要求的两个子串 A , B A,B A,B 必然满足类似 A = x p 99...9 A=xp99...9 A=xp99...9, B = x ( p + 1 ) 00...0 B=x(p+1)00...0 B=x(p+1)00...0 的形式,就是三部分,前面一部分相同的前缀 x x x, 中间两个相差为 1 1 1 的数字 p p p 和 p + 1 p+1 p+1,后面再是一部分长度相同的 99...9 99...9 99...9 (跟在 p p p 后面) 和 00...0 00...0 00...0 (跟在 p + 1 p+1 p+1 后面)。另外前面的前缀 x x x 和后缀 999 / 000 999/000 999/000 都可以为空, x x x 可以包含前导 0。
考虑用 S A M SAM SAM 求解这个问题,首先前面一部分相同的前缀 x x x,后缀自动机上随便找一个节点都行,假设我们当前找的节点是 u u u。然后 A A A 后面要跟一个 p p p, B B B 后面要跟一个 p + 1 p+1 p+1,那就是沿着 S A M SAM SAM 上节点 u u u 的出边 p p p 和出边 p + 1 p+1 p+1 走出去,分别到达 x x x 和 y y y,(毕竟 S A M SAM SAM 沿着 p a r e n t parent parent 树向下/向上走是在前面增加/减少一个字符,沿着 sam 自动机的出边走是向后面增加字符 ),如果 x x x 和 y y y 都存在的话,那我就找到了一对 A = x p A=xp A=xp, B = x ( p + 1 ) B=x(p+1) B=x(p+1) 的串,更新答案 a n s = a n s + ( l e n [ u ] − l e n [ p a r [ u ] ] ) ∗ s i z [ x ] ∗ s i z [ y ] ans = ans \ + \ (len[u] - len[par[u]]) \ * \ siz[x] \ * \ siz[y] ans=ans + (len[u]−len[par[u]]) ∗ siz[x] ∗ siz[y] ,其中 l e n [ u ] − l e n [ p a r [ u ] ] len[u] - len[par[u]] len[u]−len[par[u]] 代表 x x x 的不同长度,因为节点 u u u 代表的串(前缀 x x x)不止一个, s i z [ x ] siz[x] siz[x] 代表结点 x x x 的 e n d p o s endpos endpos 集大小,也就是节点 x x x 代表的子串 x p xp xp 在原串中有多少不同的位置,(顺带一提前面的 l e n [ u ] − l e n [ p a r [ u ] ] len[u] - len[par[u]] len[u]−len[par[u]] 也可以理解为是 x p xp xp 的不同长度 )。
加入对于
567568
567568
567568,
u
u
u 就是节点 3,沿着出边 7 和 8 分别走到节点
4
4
4 和节点
7
7
7,此时答案贡献
+
2
+2
+2,分别是
67
/
68
67/68
67/68 和
567
/
568
567/568
567/568.
(Ps; 黑色的是 parent 树的边,蓝色的带箭头的边和数字是 sam 的边,黑边上有数字是因为 parent 边和 sam 边重了)
那么,还需要统计后面那一串 999 / 000 999/000 999/000,这个直接看代码或许更直接:
int x = sam[u][i];
int y = sam[u][i+1];
while( x > 0 && y > 0 ) {
ans += 1ll * siz[x] * siz[y] * (len[u] - len[par[u]]);
x = sam[x][9];
y = sam[y][0];
}
蒽,相信你现在和我当初看这段代码时一样的困惑,这真的不会超时么
后缀自动机的节点数不超过 2 n − 1 2n-1 2n−1
做法是:先统计每个节点的 s i z siz siz,然后遍历每个节点,枚举所有的 p p p,沿着当前节点的 p p p 和 p + 1 p+1 p+1 节点走出去分别代表 A 、 B A、B A、B 串,再一直沿着 9 、 0 9、0 9、0 的出边走到头,统计路径上所有节点的答案。
问题在于对于不同的节点 u u u, “一直沿着 9 、 0 9、0 9、0 的出边走到头” 会不会遇到相同的节点?
假设我先后访问了节点 u u u 、 v v v,如果 u u u 、 v v v 代表的子串不同(其中一个不是另一个的后缀),那后面 “一直沿着 9 、 0 9、0 9、0 的出边走到头” 肯定是不会走到相同的节点的,因为这里沿着 S A M SAM SAM 出边走到的节点肯定能代表 x p 99...9 xp99...9 xp99...9 这个子串,如果前面的那个 x x x 不同,那肯定是两个不同的节点。
而如果
v
v
v 代表的子串是
u
u
u 的一个后缀呢?也就是说在
p
a
r
e
n
t
parent
parent 树上,
v
v
v 是
u
u
u 的祖先,比如:(假设这里 “…” 的地方都是
1
1
1,不会有别的
689
689
689 出现)
我之前访问了子串
56
56
56 代表的节点
u
u
u,现在又访问了子串
356
356
356 代表的节点
v
v
v,同样沿着
8
8
8 和
9
9
9 的出边,后面再 “一直沿着
9
、
0
9、0
9、0 的出边走到头”,这种时候它应该也是走不到相同的节点的,不过走到的节点好像也是存在
p
a
r
e
n
t
parent
parent 树的父子关系的,因为不管走到哪,
356899...
356899...
356899... 的出现位置肯定比
56899...
56899...
56899... 的少,那么
e
n
d
p
o
s
endpos
endpos 集就肯定不一样。所以每个节点至多被访问两次,一次是枚举这个点作为
p
p
p、
p
+
1
p+1
p+1 的起点,一个是被某个点沿着
9
/
0
9/0
9/0 的出边走到这个点。
个人代码:
const int MAX_N = 1000005;
char s[MAX_N];
int par[MAX_N<<1], sam[MAX_N<<1][10],len[MAX_N<<1];
int siz[2 * MAX_N];
int last,tot;
void sam_extend(int ch)
{
int p = last;
tot++;
int np = last = tot;
len[np] = len[p] + 1;
siz[np] = 1;
while( p>0 && sam[p][ch]==0 ){
sam[p][ch] = np;
p = par[p];
}
if( p==0 ){
par[np] = 1;
}
else{
int q = sam[p][ch];
if( len[q] == len[p]+1 )par[np] = q;
else{
tot++;
int nq = tot;
len[nq] = len[p]+1;
par[nq] = par[q];
for(int i=0;i<10;i++)sam[nq][i] = sam[q][i];
par[np] = par[q] = nq;
while( p>0 && sam[p][ch]==q ){
sam[p][ch] = nq;
p = par[p];
}
}
}
}
vector<int>edge[2 * MAX_N];
int dfs(int u) {
for(int v : edge[u]) siz[u] += dfs(v);
return siz[u];
}
long long ans ;
void dfs2(int u) {
for(int i=0;i<9;i++) {
int x = sam[u][i];
int y = sam[u][i+1];
while( x > 0 && y > 0 ) {
ans += 1ll * siz[x] * siz[y] * (len[u] - len[par[u]]);
x = sam[x][9];
y = sam[y][0];
}
}
for(int v : edge[u]) dfs2(v);
}
int main() {
int n;
scanf("%d",&n);
scanf("%s",s);
last = tot = 1;
for(int i=0;i<n;i++)sam_extend(s[i] - '0');
// 建 parent 树,dfs 统计每个节点的 siz
// 有的人也习惯用计数排序按 len 排序
for(int i=2;i<=tot;i++)edge[par[i]].push_back(i);
dfs(1);
ans = 0;
len[0] = -1; // 我这里是要让根节点 1 满足 len[1] - len[0] = 1,不然要出问题;
par[1] = 0;
dfs2(1);
printf("%lld\n",ans);
}
顺带贴一下用来画 S A M SAM SAM 的 debug 代码 QAQ
void print_sam(){
vector<int>edge[20];
for(int i=2;i<=tot;i++)edge[par[i]].push_back(i);
for(int i=1;i<=tot;i++) {
printf("child %d :",i); for(int u : edge[i])printf(" %d",u); printf("\n");
}
for(int i=1;i<=tot;i++) {
printf("sam %d, len=%d, siz = %d :\n",i,len[i],siz[i]);
for(int j=0;j<10;j++) {
if( sam[i][j] > 0 ) {
printf(" %c -> %d\n",'0'+j,sam[i][j]);
}
}
}
}