last UPD at 2018.3.14 把之前手滑的地方修正了
说在前面
被这题折磨致死…
第一次写SAM,理解不够,于是被各种细节坑上天=A=
还要复习会考,伤心
题目
BZOJ3998传送门
###题目大意
对于一个给定长度为N的字符串,求它的第K小子串是什么。
输入输出格式
输入格式:
第一行是一个仅由小写英文字母构成的字符串S
第二行为两个整数T和K,T为0则表示不同位置的相同子串算作一个。T=1则表示不同位置的相同子串算作多个。K的意义如题所述。
输出格式:
输出仅一行,为一个数字串,为第K小的子串。如果子串数目不足K个,则输出-1
解法
听说这是一道很基础很基础的题
aaaaaaaaaaaaaaaaaaaaaaaaaaaa土拨鼠叫.jpg
维护方法和AC自动机上维护信息差不多…都是通过parent(AC自动机上是fail)和child累加信息
具体做法,在后缀自动机的每个节点上开两个域cnt和sum,cnt表示该节点 代表的字符串集合 出现了多少次,sum表示该节点 代表的字符串集合 所能到达的字符串总个数。
很明显nd->sum = nd->cnt +
Σ
\Sigma
Σnd->child[i]->sum(后缀自动机上的child)
然后关于cnt的计算
如果重复子串只计算一次(T为0),那么所有节点的cnt都应该是1(这里len不会对cnt带来影响,因为len是向前扩展的。换种理解方式,就是说走过不同的前缀,最后到达的节点可能相同,因为它们结束位置相同,它们的sum和cnt也是一样的)
如果重复子串需要累计(T为1),那么在parent树上nd->cnt+=
Σ
\Sigma
Σnd->child[k]->cnt(注意是parent树上,当前串也肯定会在parent树儿子里再出现的,所以要累加上儿子的。这其实就是right集合的大小)
还有还有,因为在建SAM的时候,节点编号的顺序并不是自动机上或者parent树上的顺序,因此只能通过根据length排序来确定拓扑序,不像AC自动机可以直接逆BFS序统计。
下面是自带大常数的代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
long long K ;
int T , lens , bt[1000005] , so[1000005] ;
char ss[500005] ;
struct Node{
int len ;
long long cnt , sum ;
Node *ch[27] , *par ;
}w[1000005] , *root , *last , *tw = w ;
void GG(){
puts( "-1" ) ;
exit( 0 ) ;
}
void newNode( Node *&nd ){
nd = ++tw ;
nd->len = nd->cnt = nd->sum = 0 ;
memset( nd->ch , 0 , sizeof( nd->ch ) ) ;
nd->par = NULL ;
}
void Insert( char cc ){
Node *now , *tmp = last ;
newNode( now ) ; now->len = last->len + 1 ;
int id = cc - 'a' ;
while( tmp && !tmp->ch[id] ){
tmp->ch[id] = now ;
tmp = tmp->par ;
}
if( !tmp ) now->par = root ;
else{
if( tmp->ch[id]->len == tmp->len + 1 ) {
now->par = tmp->ch[id] ;
} else {
Node *B = tmp->ch[id] , *nB = ++tw ;
*nB = *B ;
now->par = B->par = nB ;
nB->len = tmp->len + 1 ;
while( tmp && tmp->ch[id] == B ){
tmp->ch[id] = nB ;
tmp = tmp->par ;
}
}
}
last = now ;
}
void build(){
last = root ;
for( int i = 0 ; i < lens ; i ++ )
Insert( ss[i] ) ;
}
void count(){
for( int i = tw - w ; i ; i -- ){
Node *nd = w + so[i] ;
if( !T && nd->cnt ) nd->cnt = 1 ;
if( nd != root ){
nd->par->cnt += nd->cnt ;
nd->sum = nd->cnt ;
}
for( int i = 0 ; i < 26 ; i ++ )
if( nd->ch[i] ) nd->sum += nd->ch[i]->sum ;
}
}
void solve(){
newNode( root ) ;
build() ;
for( Node *nd = w+1 ; nd <= tw ; nd ++ ) bt[nd->len] ++ ;
for( int i = 1 ; i <= lens ; i ++ ) bt[i] += bt[i-1] ;
for( Node *nd = tw ; nd != w ; nd -- ) so[ bt[nd->len]-- ] = nd - w ;
Node *tmp = root ;
for( int i = 0 ; i < lens ; i ++ ){
tmp = tmp->ch[ ss[i] - 'a' ] ;
tmp->cnt = 1 ;
}
count() ;
if( root->sum < K ) GG() ;
tmp = root ;
for( int i ; K > 0 ; ){
for( i = 0 ; i < 26 ; i ++ )
if( tmp->ch[i] ){
if( tmp->ch[i]->sum < K ) K -= tmp->ch[i]->sum ;
else break ;
}
printf( "%c" , 'a' + i ) ;
K -= tmp->ch[i]->cnt ;
tmp = tmp->ch[i] ;
}
}
int main(){
scanf( "%s" , ss ) ;
lens = strlen( ss ) ;
scanf( "%d%lld" , &T , &K ) ;
solve() ;
return 0 ;
}