【题目链接】
【思路要点】
- 补档博客，无题解。
【代码】
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>

// Splits s (1-indexed, s[1..n]) into at most k contiguous pieces so that the
// lexicographically largest substring over all pieces is as small as possible,
// and prints that substring. Input: k, then s, on stdin.
//
// Outline:
//   * suffix() builds the suffix array sa[], the rank array rnk[], height[]
//     (LCP with the NEXT suffix in sorted order) and a sparse table giving
//     O(1) LCP queries between two suffixes;
//   * distinct substrings are numbered in lexicographic order through sa[] and
//     height[]; the answer's number is binary searched in [getl(), getr()];
//   * check() greedily cuts s so that every substring inside a piece is
//     strictly smaller than the candidate, and counts the pieces.

const int MAXN = 100005;
const int MAXLOG = 20;
const int CSIZE = 256;

// Reads a (possibly negative) integer from stdin, skipping non-digit chars.
// Uses int for the character so EOF is detected instead of looping forever.
template <typename T> void read(T &x) {
    x = 0;
    int f = 1;
    int c = std::getchar();
    for (; c != EOF && !std::isdigit(c); c = std::getchar())
        if (c == '-') f = -f;
    for (; c != EOF && std::isdigit(c); c = std::getchar()) x = x * 10 + (c - '0');
    x *= f;
}

// Byte value of a character in 0..255, so it can index cnt[] and be ordered
// the same way everywhere regardless of whether plain char is signed.
inline int uc(char c) { return static_cast<unsigned char>(c); }

int n, k;
char s[MAXN];       // s[1..n] is the input; s[0] and s[n + 1] are '\0'
int start, length;  // current candidate substring: s[start .. start + length - 1]
// Named `rnk`, not `rank`: combined with `using namespace std` the latter
// collides with std::rank (<type_traits>, C++11) and does not compile.
int rnk[MAXN], sa[MAXN], height[MAXN];
int bit[MAXN];           // bit[i] = floor(log2(i))
int st[MAXLOG][MAXN];    // sparse table of minima over height[]

// Length of the longest common prefix of the suffixes starting at x and y.
int LCP(int x, int y) {
    if (x == y) return n - x + 1;
    x = rnk[x];
    y = rnk[y];
    if (x > y) std::swap(x, y);
    // height[r] = LCP(sa[r], sa[r + 1]), so the answer is min(height[x..y-1]).
    const int t = bit[y - x];
    return std::min(st[t][x], st[t][y - (1 << t)]);
}

// Builds sa[], rnk[], height[], bit[] and st[] for s[1..n] (prefix doubling
// with radix sort, then Kasai's LCP algorithm).
void suffix() {
    static int cnt[MAXN], x[MAXN], y[MAXN], rk[MAXN];

    // Round 0: bucket-sort suffixes by their first character.
    std::memset(cnt, 0, sizeof(cnt));
    for (int i = 1; i <= n; i++) cnt[uc(s[i])]++;
    for (int i = 1; i < CSIZE; i++) cnt[i] += cnt[i - 1];
    for (int i = n; i >= 1; i--) sa[cnt[uc(s[i])]--] = i;
    rnk[sa[1]] = 1;
    for (int i = 2; i <= n; i++)
        rnk[sa[i]] = rnk[sa[i - 1]] + (s[sa[i]] != s[sa[i - 1]] ? 1 : 0);

    // Doubling rounds until all ranks are distinct.
    for (int len = 1; rnk[sa[n]] != n; len <<= 1) {
        for (int i = 1; i <= n; i++) {
            x[i] = rnk[i];
            y[i] = (i + len <= n) ? rnk[i + len] : 0;  // 0 = past the end
        }
        // Radix sort by the second key y (values 0..n)...
        std::memset(cnt, 0, sizeof(cnt));
        for (int i = 1; i <= n; i++) cnt[y[i]]++;
        for (int i = 1; i <= n; i++) cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; i--) rk[cnt[y[i]]--] = i;
        // ...then stably by the first key x.
        std::memset(cnt, 0, sizeof(cnt));
        for (int i = 1; i <= n; i++) cnt[x[i]]++;
        for (int i = 1; i <= n; i++) cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; i--) sa[cnt[x[rk[i]]]--] = rk[i];
        rnk[sa[1]] = 1;
        for (int i = 2; i <= n; i++) {
            const bool differs = x[sa[i]] != x[sa[i - 1]] || y[sa[i]] != y[sa[i - 1]];
            rnk[sa[i]] = rnk[sa[i - 1]] + (differs ? 1 : 0);
        }
    }

    // Kasai: height[rnk[i]] = LCP of suffix i with its successor in sa[].
    int now = 0;
    for (int i = 1; i <= n; i++) {
        if (now) now--;
        if (rnk[i] == n) {  // largest suffix has no successor
            now = 0;
            height[n] = 0;
            continue;
        }
        const int j = sa[rnk[i] + 1];
        // s[n + 1] == '\0' stops the scan at the end of the shorter suffix.
        while (s[j + now] == s[i + now]) now++;
        height[rnk[i]] = now;
    }

    // Sparse table for range-minimum over height[1..n].
    for (int i = 1; i <= n; i++) {
        st[0][i] = height[i];
        bit[i] = bit[i - 1] + (i >= (2 << bit[i - 1]) ? 1 : 0);
    }
    for (int p = 1; p < MAXLOG; p++) {
        const int half = 1 << (p - 1);
        for (int i = 1; i + half <= n; i++)
            st[p][i] = std::min(st[p - 1][i], st[p - 1][i + half]);
    }
}

// 1-based lexicographic number (among distinct substrings) of the single
// largest character of s. Each suffix sa[i] owns the
// n - sa[i] + 1 - height[i] prefixes it does not share with sa[i + 1]; the
// loop sums them over all suffixes that start with a smaller character.
long long getl() {
    long long ans = 1;
    for (int i = 1; s[sa[i]] != s[sa[n]]; i++) ans += n - sa[i] + 1 - height[i];
    return ans;
}

// Number of distinct substrings of s, i.e. the largest substring number.
long long getr() {
    long long ans = 0;
    for (int i = 1; i <= n; i++) ans += n - sa[i] + 1 - height[i];
    return ans;
}

// Sets (start, length) to the rk-th smallest distinct substring of s.
void gett(long long rk) {
    int pos = 1;
    // The prefixes of sa[pos] are the smallest substrings not skipped yet; if
    // rk is beyond its length, skip those not shared with sa[pos + 1].
    for (; n - sa[pos] + 1 < rk; pos++) rk -= n - sa[pos] + 1 - height[pos];
    start = sa[pos];
    length = static_cast<int>(rk);
}

// True iff s can be cut into at most maxPieces pieces such that every
// substring of every piece is strictly smaller than the candidate
// T = s[start .. start + length - 1]. Greedy: extend the current piece and
// cut at the tightest end forced by any start position inside it.
bool check(int maxPieces) {
    // Assumes T begins with the largest character of s (true for every rank
    // in the binary-search range); a lone character can then not be beaten.
    if (length == 1) return false;
    int pieces = 0, endpoint = n;
    for (int i = 1; i <= n; i++) {
        const int common = LCP(i, start);
        if (common >= length)
            // s[i..] begins with T: keep at most length - 1 chars from i.
            endpoint = std::min(endpoint, i + length - 2);
        else if (uc(s[i + common]) > uc(s[start + common]))
            // s[i..] exceeds T at offset `common`: keep at most `common` chars.
            endpoint = std::min(endpoint, i + common - 1);
        if (i == endpoint) {
            pieces++;
            endpoint = n;
        }
    }
    return pieces <= maxPieces;
}

int main() {
    read(k);
    if (std::scanf("\n%s", s + 1) != 1) return 0;  // no string to split
    n = static_cast<int>(std::strlen(s + 1));
    if (n == 0) return 0;
    suffix();
    long long l = getl(), r = getr();
    // Find the largest number whose substring cannot be strictly undercut.
    while (l < r) {
        const long long mid = (l + r + 1) / 2;
        gett(mid);
        if (check(k))
            r = mid - 1;
        else
            l = mid;
    }
    gett(l);
    s[start + length] = 0;
    std::printf("%s\n", s + start);
    return 0;
}