一开始写了一发后缀数组。
T=0的时候直接把所有的n-sa[i]+1-h[i]加到一起一直到K就行了。
T=1时枚举每一位是什么,二分求一下范围,处理一个n-sa[i]+1的前缀和。
本来以为会T,结果并没有。。。
#include <bits/stdc++.h>
using namespace std;
#define N 510000
#define ll long long
char s[N];
int n,T;
int h[N],sa[N],rank[N],tr[N],has[N];
ll K,sum[N];
int cmp(int x,int y,int k)
{
if(x+k>n||y+k>n)return 0;
return rank[x]==rank[y]&&rank[x+k]==rank[y+k];
}
void getsa()
{
int i,cnt;
for(i=1;i<=n;i++)has[s[i]-'a'+1]++;
for(i=1,cnt=0;i<=26;i++)if(has[i])tr[i]=++cnt;
for(i=1;i<=26;i++)has[i]+=has[i-1];
for(i=1;i<=n;i++)rank[i]=tr[s[i]-'a'+1],sa[has[s[i]-'a'+1]--]=i;
for(int k=1;cnt!=n;k<<=1)
{
for(i=1;i<=n;i++)has[i]=0;
for(i=1;i<=n;i++)has[rank[i]]++;
for(i=1;i<=n;i++)has[i]+=has[i-1];
for(i=n;i>=1;i--)if(sa[i]>k)tr[sa[i]-k]=has[rank[sa[i]-k]]--;
for(i=1;i<=k;i++)tr[n-i+1]=has[rank[n-i+1]]--;
for(i=1;i<=n;i++)sa[tr[i]]=i;
for(i=1,cnt=0;i<=n;i++)tr[sa[i]]=cmp(sa[i],sa[i-1],k) ? cnt:++cnt;
for(i=1;i<=n;i++)rank[i]=tr[i];
}
for(i=1;i<=n;i++)
{
if(rank[i]==1)continue;
for(int j=max(h[rank[i-1]]-1,1);;j++)
{
if(s[i+j-1]==s[sa[rank[i]-1]+j-1])h[rank[i]]=j;
else break;
}
}
}
void print(int l,int r)
{
for(int i=l;i<=r;i++)
putchar(s[i]);
}
int main()
{
//freopen("tt.in","r",stdin);
scanf("%s",s+1);n=strlen(s+1);
getsa();
scanf("%d%lld",&T,&K);
if(T==0)
{
for(int i=1;i<=n;i++)
{
int t=n-sa[i]+1-h[i];
if(K<=t)return print(sa[i],sa[i]+h[i]+K-1),0;
K-=t;
}
}
else
{
for(int i=1;i<=n;i++)
sum[i]=sum[i-1]+n-sa[i]+1;
if(sum[n]<K)return puts("-1"),0;
int l1=1,r1=n;
for(int i=1;i<=n;i++)
{
int l2=l1;
for(int j='a';j<='z';j++)
{
int l=l2,r=r1;
while(l<=r)
{
int mid=(l+r)>>1;
if(s[sa[mid]+i-1]>j)r=mid-1;
else l=mid+1;
}
ll t=sum[r]-sum[l2-1]-(ll)(r-l2+1)*(i-1);
if(t>=K)
{
if(r-l2+1>=K)return print(sa[l2],sa[l2]+i-1),0;
l1=l2;r1=r;K-=r-l2+1;break;
}
l2=r+1;K-=t;
}
if(n-sa[l1]+1==i)l1++;
}
}
puts("-1");
return 0;
}
看大家都写后缀自动机,顺便写了一发。
#include <bits/stdc++.h>
using namespace std;
#define N 1100000
char s[N];
int n,T,K;
struct SAM
{
int trs[N][26],fa[N],len[N],last,cnt;
int v[N],sum[N],pos[N],has[N];
void init(){last=cnt=1;}
void insert(int x)
{
int p=last,np=++cnt,q,nq;
last=np;len[np]=len[p]+1;v[np]=1;
for(;p&&!trs[p][x];p=fa[p])trs[p][x]=np;
if(!p)fa[np]=1;
else
{
q=trs[p][x];
if(len[q]==len[p]+1)fa[np]=q;
else
{
fa[nq=++cnt]=fa[q];
len[nq]=len[p]+1;
memcpy(trs[nq],trs[q],sizeof(trs[q]));
fa[q]=fa[np]=nq;
for(;trs[p][x]==q;p=fa[p])trs[p][x]=nq;
}
}
}
void cal()
{
for(int i=1;i<=cnt;i++)has[len[i]]++;
for(int i=1;i<=n;i++)has[i]+=has[i-1];
for(int i=1;i<=cnt;i++)pos[has[len[i]]--]=i;
for(int i=cnt,t;i>=1;i--)
{
t=pos[i];
if(T==0)v[t]=1;
else v[fa[t]]+=v[t];
}
v[1]=0;
for(int i=cnt,t;i>=1;i--)
{
t=pos[i];sum[t]=v[t];
for(int j=0;j<26;j++)
sum[t]+=sum[trs[t][j]];
}
}
void dfs(int x,int K)
{
if(v[x]>=K)return;
K-=v[x];
for(int i=0;i<26;i++)
{
if(K<=sum[trs[x][i]])
{
putchar(i+'a');
dfs(trs[x][i],K);
return;
}
K-=sum[trs[x][i]];
}
}
}sam;
int main()
{
//freopen("tt.in","r",stdin);
scanf("%s",s+1);
n=strlen(s+1);sam.init();
for(int i=1;i<=n;i++)
sam.insert(s[i]-'a');
scanf("%d%d",&T,&K);
sam.cal();
if(K>sam.sum[1])return puts("-1"),0;
sam.dfs(1,K);
return 0;
}