1.判断字符串出现次数
SAM中,每个节点代表的是endpos集合相同的一些状态
endpos是字串末尾的集合,故endpos的大小就是出现次数
考虑求出endpos集合大小
对于parent tree上的一个节点,他的子节点每出现一次,这个点就会相应的出现一次(因为这个点一定是他的子节点的子串)
故每个点出现次数是它的子节点出现次数和+1(+1这个节点所代表字符串本身出现的第一次)
void dfs(int x){
for(int i=0;i<tu[x].size();i++){
int p=tu[x][i];
dfs(p);
mes[x].cnt+=mes[p].cnt;
}
if(mes[x].cnt>1) ans=max(ans,1ll*mes[x].cnt*mes[x].len);
}
void get_ans(){
for(int i=1;i<=sz;i++) tu[mes[i].fa].push_back(i);
dfs(1);
printf("%lld",ans);
}
2.判断不同子串总个数
考虑parent tree的特殊性质
每个节点所代表的是endpos相同的点集,故可以知道每个节点代表的所有字串长度不同且完全覆盖[minlen,maxlen]这个区间(因为这个点的endpos集合是minlen-1的子集,maxlen+1是这个集合的子集)
所以求出每个节点代表的最长串长度maxlen(x),减掉他的父亲的maxlen(fa),就是这个点所代表的长度区间
动态维护不同字串个数的代码:
void insert(int ch){
int now=++sz;
mes[now].len=mes[las].len+1;
while(las&&!mes[las].to[ch]) mes[las].to[ch]=now,las=mes[las].fa;
if(!las) mes[now].fa=1,ans+=mes[now].len;
else{
int to=mes[las].to[ch];
if(mes[las].len+1==mes[to].len) mes[now].fa=to,ans=ans+mes[now].len-mes[to].len;
else{
int np=++sz;
mes[np]=mes[to];
mes[np].len=mes[las].len+1;
ans=ans+mes[np].len-mes[mes[np].fa].len;
while(las&&mes[las].to[ch]==to) mes[las].to[ch]=np,las=mes[las].fa;
ans=ans-(mes[to].len-mes[mes[to].fa].len);
mes[to].fa=mes[now].fa=np;
ans=ans-mes[np].len*2+mes[to].len+mes[now].len;
}
}
las=now;
printf("%lld\n",ans);
}
3.1求本质不同第k大
类比一下求平衡树rank的方法
因为SAM是个DAG,考虑预处理出来每个点以及后续所有节点的总个数sz
每次遍历字符集,若k>sz就让k减掉那个size,否则说明在这一位上就是这个字符,进入这个节点继续查询即可
每次开始-1是减去现在遍历的这个状态所代表的字符串,k到0表示查询到了
dfs实现代码
void dfs(int x){
if(mes[x].sz) return ;
mes[x].sz=1;
for(char ch='a';ch<='z';ch++){
int p=mes[x].nxt[ch];
if(p) dfs(p),mes[x].sz+=mes[p].sz;
}
}
void get_kth(int x,int k){
k--;
if(!k) return ;
for(char ch='a';ch<='z';ch++){
int p=mes[x].nxt[ch];
if(!p) continue;
if(k<=mes[p].sz){
putchar(ch),get_kth(p,k);
break;
}
if(k>mes[p].sz) k-=mes[p].sz;
}
}
void get(int k){
anslen=0;
if(k>mes[1].sz-1){
cout<<-1;
return ;
}
get_kth(1,k);
putchar('\n');
}
3.2求所有字串第k大(不同位置算多个
能求出每个字串出现了多少次,能求出第k大
缝合一下
本来每个节点的sz值为1,现在令每个节点的sz值是endpos集合大小就行
开始时k-1变成k-endpos
这个用bfs实现的
void dfs(int x){
for(int ch=0;ch<26;ch++){
int p=mes[x].nxt[ch];
if(p&&!vis[p]) vis[p]=true,dfs(p);
mes[x].sz+=mes[p].sz;
}
}
void get_endpos(int x){
for(int i=0;i<tu[x].size();i++){
int p=tu[x][i];
get_endpos(p);
mes[x].endpos+=mes[p].endpos;
}
if(x!=1) mes[x].sz=mes[x].endpos;
}
void get_sz(){
if(op){
for(int i=2;i<=sz;i++) tu[mes[i].fa].push_back(i);
get_endpos(1);
}
dfs(1);
if(!op) mes[1].sz--;
}
void get_kth(int k){
int x=1;
if(k>mes[x].sz){
printf("-1");
return ;
}
while(true){
if(x!=1) k-=(op?mes[x].endpos:1);
if(k<=0) break;
for(int ch=0;ch<26;ch++){
int p=mes[x].nxt[ch];
if(!p) continue;
if(k>mes[p].sz) k-=mes[p].sz;
else{
x=p;putchar(ch+'a');
break;
}
}
}
}
4.求循环移位(最小表示法
将一个字符串开头字符移到末尾,常见的处理方式是复制两遍字符串
因为要的是最小表示,故贪心每次走最小字符,求出长度为字符串原长的答案
//用map存的集合,容易找出最小的边
void get(int l){
int len=0,x=1;
while(len<l){
printf("%d ",(*mes[x].nxt.begin()).first);
x=(*mes[x].nxt.begin()).second;
len++;
}
}
5.1求两个字符串的最长公共字串
先将一个字符串构建SAM
然后依次去查询令一个字符串的每一个字符
若有这个字符,答案长度+1,否则跳parent tree边知道根节点(答案变成新节点所代表的子串集合中最长的字符串长度)
void query(char *s){
int len=strlen(s+1);
int p=1,ans=0,ansl=0;
for(int i=1;i<=len;i++){
while(p!=1&&!mes[p].nxt[s[i]]) p=mes[p].fa,ansl=mes[p].len;
if(mes[p].nxt[s[i]]) p=mes[p].nxt[s[i]],ansl++;
ans=max(ans,ansl);
}
printf("%d",ans);
}
5.2 求多个字符串的最长公共子串
先思考一件事情
每个节点经过一次,那么其parent tree上的所有祖先相当于全部都访问过一遍(因为是它的子串
按照求单个相同子串的方法,在每个经过的节点上用现在的答案来更新这个节点的答案
然后dfs一遍parent tree,用祖先节点加上子节点的答案
将每个节点每次答案取最小值,就是这个节点所代表公共子串的长度
将每个节点所代表的长度取最大值,就是最长公共子串
还有一点,每次的复杂度是O(max(文本串长度,模式串长度)),所以最好用最短的串来构造SAM,这样保证复杂度最大为字符串总长
void prepare(){
for(int i=1;i<=sz;i++) sma[i]=mes[i].len;//最小值先赋为最大长度
for(int i=1;i<=sz;i++){
if(mes[i].fa) tu[mes[i].fa].push_back(i);//parent tree
}
}
void dfs(int x){
for(int i=0;i<tu[x].size();i++){
int p=tu[x][i];
dfs(p),mx[x]=min(max(mx[x],mx[p]),mes[x].len);//处理每次的最大值
}
}
void get(char *s){
memset(mx,0,sizeof(mx));
int len=strlen(s+1);
int p=1,now=0;
for(int i=1;i<=len;i++){
while(p!=1&&!mes[p].nxt[s[i]]) p=mes[p].fa,now=mes[p].len;
if(mes[p].nxt[s[i]]) now++,p=mes[p].nxt[s[i]],mx[p]=max(mx[p],now);
}
dfs(1);
for(int i=1;i<=sz;i++) sma[i]=min(sma[i],mx[i]);//更新答案
}
void get_ans(){
int ans=0;
for(int i=1;i<=sz;i++) ans=max(ans,sma[i]);
cout<<ans<<endl;
}