题目描述
传送门
题目大意:有一个长度为 n n n 的字符串, 你需要把它分成不超过k 段, 设第 i 段的字典序最大的子串为 Ci , 现在求 Ci 中字典序最大的那个最小能是多少。
题解
看到最小值最大,比较容易想到的思路就是二分。
对于字符串建立后缀数组,字符串中所有的本质不同的子串的个数是
∑i=1nn−sa[i]+1−height[i]
我们可以二分子串,判断是否可行。首先用与统计类似的方式找到第mid个子串[l,r],然后对字符串贪心的进行划分。从后向前,如果[i,last]的字典序大于[l,r],那么i后面就需要划分一刀,如果最后划分成的数量超过k,那么就不能满足。需要特别注意,如果s[i]>s[l],那么直接判断不能满足。
对于判断两个串的字典序大小,先用st表找出两个串的lcp,然后分类讨论一下就可以了。
需要特别注意的是查询的时候如果x=y,那么需要特判。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 600003
#define LL long long
using namespace std;
int n,a[N],sa[N],*y,*x,xx[N],yy[N],rank[N],st[N][21],L[N],height[N],ls,rs,ansl,ansr,K,v[N];
char s[N];
int cmp(int i,int j,int k)
{
return y[i]==y[j]&&(i+k>n?-1:y[i+k])==(j+k>n?-1:y[j+k]);
}
void get_SA()
{
x=xx; y=yy; int m=30; int p;
for (int i=1;i<=n;i++) v[x[i]=a[i]]++;
for (int i=1;i<=m;i++) v[i]+=v[i-1];
for (int i=n;i>=1;i--) sa[v[x[i]]--]=i;
for (int k=1;k<=n;k<<=1) {
p=0;
for (int i=n-k+1;i<=n;i++) y[++p]=i;
for (int i=1;i<=n;i++)
if (sa[i]>k) y[++p]=sa[i]-k;
for (int i=0;i<=m;i++) v[i]=0;
for (int i=1;i<=n;i++) v[x[y[i]]]++;
for (int i=1;i<=m;i++) v[i]+=v[i-1];
for (int i=n;i>=1;i--) sa[v[x[y[i]]]--]=y[i];
swap(x,y);
p=2; x[sa[1]]=1;
for (int i=2;i<=n;i++)
x[sa[i]]=(cmp(sa[i],sa[i-1],k)?p-1:p++);
if (p>n) break;
m=p+1;
}
for (int i=1;i<=n;i++) rank[sa[i]]=i;
p=0;
for (int i=1;i<=n;i++) {
if (rank[i]==1) continue;
int j=sa[rank[i]-1];
while (j+p<=n&&i+p<=n&&a[j+p]==a[i+p]) p++;
height[rank[i]]=p;
p=max(p-1,0);
}
for (int i=1;i<=n;i++) st[i][0]=height[i];
for (int i=1;i<=19;i++)
for (int j=1;j<=n;j++)
if (j+(1<<i)-1<=n)
st[j][i]=min(st[j][i-1],st[j+(1<<i-1)][i-1]);
int j=0;
for (int i=1;i<=n;i++) {
if ((1<<j+1)<=i) j++;
L[i]=j;
}
}
void get_kth(LL t)
{
for (int i=1;i<=n;i++)
if (t>(n-height[i]-sa[i]+1)) t-=(n-height[i]-sa[i]+1);
else {
ls=sa[i];
rs=sa[i]+height[i]+t-1;
return;
}
}
int calc(int x,int y)
{
if (x==y) return n-sa[x]+1;
if (x>y) swap(x,y);
int k=L[y-x]; x++;
return min(st[x][k],st[y-(1<<k)+1][k]);
}
bool compare(int l1,int r1,int l2,int r2)
{
int len1=r1-l1+1; int len2=r2-l2+1; int lcp=calc(rank[l1],rank[l2]);
if (lcp>=len1) return len1<=len2;
if (lcp>=len2) return 0;
return a[l1+lcp]<=a[l2+lcp];
}
bool check(int mid)
{
int last=n; int cnt=1;
for (int i=n;i>=1;i--){
if (a[i]>a[ls]) return 0;
if (!compare(i,last,ls,rs)) {
cnt++; last=i;
if (cnt>K) return 0;
}
}
return 1;
}
int main()
{
scanf("%d",&K); scanf("%s",s+1);
n=strlen(s+1);
for (int i=1;i<=n;i++) a[i]=s[i]-'a'+1;
get_SA();
LL l=1; LL r=0;
for (int i=1;i<=n;i++) r+=(n-height[i]-sa[i]+1);
while (l<=r) {
LL mid=(l+r)/2;
get_kth(mid);
if (check(mid)) ansl=ls,ansr=rs,r=mid-1;
else l=mid+1;
}
for (int i=ansl;i<=ansr;i++) printf("%c",s[i]);
printf("\n");
}