大白书上的代码有点问题,主要是重复的后缀无法正常排序(如aaaaaaa,排序后是乱序的3456120),因此不利于用二分查找输出所有结果。然后最长公共祖先LCP部分的代码越界访问,导致height[0]不是0。所以自己打了一份代码供自己参考。
代码
#include<bits/stdc++.h>
using namespace std;
const int maxn = 100010;
char s[maxn],ss[maxn];
int sa[maxn],t1[maxn],t2[maxn],c[maxn];
int l;
void get_sa(int m)
{
int *x=t1,*y=t2;
l=strlen(s);
for(int i=0;i<l;i++)
x[sa[i]=i]=s[i]-'a';
for(int i=0;i<=l;i=(i?i<<1:i+1))
{
int p=0;
for(int j=l-i;j<l;j++) y[p++]=j;
for(int j=0;j<l;j++) if(sa[j]>=i) y[p++]=sa[j]-i;
for(int j=0;j<m;j++) c[j]=0;
for(int j=l-1;j>=0;j--) c[x[y[j]]]++;
for(int j=1;j<m;j++) c[j]+=c[j-1];
for(int j=l-1;j>=0;j--) sa[--c[x[y[j]]]]=y[j];
swap(x,y);
p=1,x[sa[0]]=0;
for(int j=1;j<l;j++)
x[sa[j]]=y[sa[j-1]]==y[sa[j]]&&(sa[j-1]+i>=l?-1:y[sa[j-1]+i])==(sa[j]+i>=l?-1:y[sa[j]+i])?p-1:p++;
if(p>=l) return;
m=p;
}
}
int m;
int cmp_str(int p,char* ss)
{
return strncmp(ss,s+sa[p],m);
}
int find(char* ss)
{
m=strlen(ss);
int L=0,R=l-1;
while(L<=R)
{
int M=L+(R-L)/2;
int ans=cmp_str(M,ss);
if(!ans) return M;
else if(ans>0) L=M+1;
else R=M-1;
}
return -1;
}
void find_all(char* ss)
{
m=strlen(ss);
int L1=0,R1=l;
while(L1<R1)
{
int M=L1+(R1-L1)/2;
if(cmp_str(M,ss)<=0) R1=M;
else L1=M+1;
}
int L2=0,R2=l;
while(L2<R2)
{
int M=L2+(R2-L2)/2;
if(cmp_str(M,ss)>=0) L2=M+1;
else R2=M;
}
printf("{");
for(int i=L1;i<L2;i++)
printf("%d",sa[i]);
puts("}");
}
int rank[maxn],height[maxn];
void get_hight()
{
for(int i=0;i<l;i++) rank[sa[i]]=i;
int k=0;
height[0]=0;
for(int i=0;i<l;i++)
{
if(!rank[i]) continue;
if(k) k--;
int j=sa[rank[i]-1];
while(s[i+k]==s[j+k]) k++;
height[rank[i]]=k;
}
printf("height:");
for(int i=0;i<l;i++)
printf("%d ",height[i]);
puts("");
}
const int POW=20;
int st[maxn][POW];
void get_st()
{
for(int i=0;i<l;i++)
st[i][0]=height[i];
for(int i=1;(1<<i)<=l;i++)
for(int j=0;j+(1<<i)-1<l;j++)
st[j][i]=min(st[j][i-1],st[j+(1<<(i-1))][i-1]);
}
int rmq(int i,int j)
{
i=rank[i],j=rank[j];
if(i>j) swap(i,j);
i++;
int k=0;
while((1<<(k+1))<=j-i+1) k++;
return min(st[i][k],st[j-(1<<k)+1][k]);
}
int main()
{
scanf("%s",s);
l=strlen(s);
get_sa(26);
get_hight();
get_st();
//输出sa数组。
/*
for(int i=0;i<l;i++) printf("%d ",sa[i]);
puts("");
*/
//输入模式串,找到一个匹配的后缀
/*
while(~scanf("%s",ss)) printf("找到后缀%d\n",find(ss));
*/
//输入模式串,找出所有匹配的后缀。
/*
while(~scanf("%s",ss)) printf("后缀%d\n集合",sa[find(ss)]),find_all(ss);
*/
//输入后缀i,j,输出LCP长度。
/*
int i,j;
while(~scanf("%d %d",&i,&j)) printf("后缀%d和后缀%d的LCP的长度为:%d\n",i,j,rmq(i,j));
*/
return 0;
}