// 思路:二分答案。原串的本质不同子串最多 O(n^2) 个,按字典序排名后在排名上二分。
// 已知子串求排名、已知排名求子串,都可以借助后缀数组的 height 数组完成
// (如果有多组询问,还可以用二分来加速)。
// 判定某个候选子串是否可行时,贪心地划分原串即可。
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
// Makes std names visible unqualified; the code below calls min, swap, etc. without std::.
using namespace std;
// Shorthand for long long; used for the substring count `sum` and the ranks passed to find/check.
typedef long long LL;
// Reads the next decimal integer from stdin.
//
// Skips every character that is not a digit; the integer is negated when the
// last skipped character (the one immediately before the digits) is '-'.
// The first character after the digits is consumed as well.
// Returns 0 if the input ends before any digit is found.
inline int read()
{
    int x=0;
    bool f=0;
    // Keep the character in an int: a char cannot represent EOF reliably, and
    // without the explicit EOF test the skip loop would never terminate.
    int c=getchar();
    for (;c!=EOF&&(c<'0'||c>'9');c=getchar()) f=(c=='-');
    for (;c>='0'&&c<='9';c=getchar()) x=x*10+(c-'0');
    return f?-x:x;
}
// N: capacity of every per-position array below.
const int N=100010;
// P: number of levels in the sparse table mn.
const int P=20;
// oo: large sentinel value; not referenced by the code visible in this file.
const int oo=0x3f3f3f3f;

// K: first integer read in main; check() fails once its piece counter exceeds it.
int K;
// n: strlen(str), set in main.
int n;
// c: counting-sort buckets used by cal_sa.
int c[N];
// na0 / na1: the two key buffers cal_sa swaps between on each doubling round.
int na0[N];
int na1[N];
// sa[r]: start position of the suffix with rank r.
int sa[N];
// rk[i]: rank of the suffix starting at position i (inverse of sa).
int rk[N];
// h[r]: number of leading characters shared by the suffixes ranked r and r-1.
int h[N];
// mn[j][i]: minimum of h over ranks j .. j+2^i-1 (filled by build_st).
int mn[N][P];
// log[i]: floor(log2(i)) for i >= 1 (filled by build_st).
// NOTE(review): this name is shared with the C math function log(); if <cmath> or
// <math.h> ever gets pulled in, this declaration may stop compiling - consider renaming.
int log[N];
// sum: total accumulated in main as the sum of n-sa[i]-h[i] over ranks 1..n.
LL sum;
// str: the input string, NUL-terminated.
char str[N];
void cal_sa(int n,int m)
{
int *x=na0,*y=na1;
for (int i=0;i<n;i++) c[x[i]=str[i]]++;
for (int i=1;i<=m;i++) c[i]+=c[i-1];
for (int i=n-1;~i;i--) sa[--c[x[i]]]=i;
for (int k=1,p=1;p<n;k<<=1,m=p)
{
p=0;
for (int i=n-k;i<n;i++) y[p++]=i;
for (int i=0;i<n;i++) if (sa[i]>=k) y[p++]=sa[i]-k;
memset(c,0,sizeof(int)*(m+1));
for (int i=0;i<n;i++) c[x[y[i]]]++;
for (int i=1;i<=m;i++) c[i]+=c[i-1];
for (int i=n-1;~i;i--) sa[--c[x[y[i]]]]=y[i];
swap(x,y);p=1;x[sa[0]]=0;
for (int i=1;i<n;i++)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?p-1:p++;
}
for (int i=0;i<n;i++) rk[sa[i]]=i;
for (int i=0,j,k=0;i<n-1;h[rk[i++]]=k)
for ((k?k--:0),j=sa[rk[i]-1];str[i+k]==str[j+k];k++);
}
// Fills log[2..n] with floor(log2) values and builds the sparse table mn
// of range minima over h[1..n]: mn[j][i] = min(h[j], ..., h[j+2^i-1]).
void build_st()
{
    // log[0] and log[1] keep the zero value they have as globals.
    for (int v=2;v<=n;v++)
        log[v]=log[v/2]+1;

    for (int pos=1;pos<=n;pos++)
        mn[pos][0]=h[pos];

    for (int lvl=1;lvl<P;lvl++)
    {
        const int half=1<<(lvl-1);
        const int span=half<<1;
        // A window of length span is the union of two windows of length half.
        for (int pos=1;pos+span-1<=n;pos++)
            mn[pos][lvl]=min(mn[pos][lvl-1],mn[pos+half][lvl-1]);
    }
}
inline int lcp(int a,int b)
{
if (a==b) return n-a;
a=rk[a];b=rk[b];
if (a>b) swap(a,b);a++;
int lg=log[b-a];
return min(mn[a][lg],mn[b-(1<<lg)+1][lg]);
}
// Returns true when substring str[l1..r1] is lexicographically less than or
// equal to substring str[l2..r2] (both ranges inclusive).
inline bool cmp(int l1,int r1,int l2,int r2)
{
    const int len1=r1-l1+1,len2=r2-l2+1;
    // Length of the common prefix, capped at the shorter substring.
    const int common=min(lcp(l1,l2),min(len1,len2));

    if (common==len1) return true;   // first is a prefix of (or equal to) second
    if (common==len2) return false;  // second is a proper prefix of first
    return str[l1+common]<=str[l2+common];
}
// Output of find(): start position and length of the selected substring.
int _s,_l;

// Locates the x-th substring in the enumeration that walks ranks 1..n and
// counts, for each rank r, the n-sa[r]-h[r] prefixes of that suffix longer
// than h[r]. Stores its start in _s and its length in _l. If x exceeds the
// total count, _s and _l are left unchanged.
inline void find(LL x)
{
    for (int r=1;r<=n;r++)
    {
        const int fresh=n-sa[r]-h[r];
        if (x<=fresh)
        {
            _s=sa[r];
            _l=h[r]+x;
            return;
        }
        x-=fresh;
    }
}
// Feasibility test for the binary search: selects the m-th substring via
// find() as the candidate, then scans str from right to left, greedily
// extending the current piece while str[i..pieceEnd] stays <= the candidate
// and starting a new piece at i otherwise. Returns false if some single
// character already exceeds the candidate's first character or if more than
// K pieces are needed; true otherwise.
bool check(LL m)
{
    find(m);

    int pieces=1,pieceEnd=n-1;
    const int candEnd=_s+_l-1;
    for (int i=n-1;i>=0;i--)
    {
        if (str[i]>str[_s]) return false;
        if (!cmp(i,pieceEnd,_s,candEnd))
        {
            ++pieces;
            pieceEnd=i;
        }
        if (pieces>K) return false;
    }
    return true;
}
// Reads K and the string, builds the suffix structures, binary-searches the
// smallest substring rank accepted by check(), and prints that substring
// followed by a newline.
int main()
{
    K=read();

    // Bounded read: str has room for N = 100010 chars, i.e. at most 100009
    // characters plus the terminator. Keep the width in sync with N.
    // If no token can be read, fall back to the empty string.
    if (scanf("%100009s",str)!=1) str[0]='\0';
    n=strlen(str);

    // n+1 so that the terminating '\0' takes part as the end marker;
    // 300 is the bucket-key upper bound for the first counting sort.
    cal_sa(n+1,300);
    build_st();

    // Total number of substrings enumerated by find().
    for (int i=1;i<=n;i++) sum+=n-sa[i]-h[i];

    // Invariant: rank l is rejected (or 0), rank r is the best accepted so far.
    LL l=0,r=sum,mid;
    while (l+1<r)
    {
        mid=(l+r)>>1;
        if (check(mid)) r=mid;
        else l=mid;
    }

    find(r);
    for (int i=_s;i<_s+_l;i++) putchar(str[i]);
    puts("");
    return 0;
}