题意:
给定一个长为 n 的串 s,再有 q 次询问,每次询问子串 s[l…r] 在原串中第 k 次出现的起始位置(按起始位置从小到大计数);若 s[l…r] 的出现次数少于 k,输出 -1。(n, q, k ≤ 1e5)
链接:
https://cn.vjudge.net/problem/HDU-6704
题解:
对于一个子串 s[l…r] 的出现位置,可以转化为求所有与 suffix[l] 的最长公共前缀(LCP)长度大于等于 r - l + 1 的后缀。这些后缀在后缀数组中的排名构成包含 rk[l] 的一段连续区间,借助 height 数组(用 ST 表求区间最小值)二分即可求出该区间的左右端点。至于第 k 次出现,即为求该排名区间内所有后缀起始位置中的第 k 小值,用主席树(可持久化线段树)即可求解。
参考代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
// Contest-style shorthand macros.
#define pb push_back
#define sz(a) ((int)a.size())
#define mem(a, b) memset(a, b, sizeof a)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
// Midpoint of the current [l, r]; expands using the caller's local `l` and `r`.
#define gmid (l + r >> 1)
// Array bound for the string length n (arrays are 1-indexed).
const int maxn = 1e5 + 5;
// NOTE(review): maxm and mod are not referenced in this program.
const int maxm = 2e5 + 5;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
// Input string, read into s[1..n]; s[0] is never written and stays '\0'.
char s[maxn];
// Suffix array data (1-indexed):
//   sa[i] - start position of the suffix with rank i
//   rk[p] - rank of the suffix starting at p
//   hi[i] - LCP of suffixes sa[i - 1] and sa[i] (hi[1] is never read by ask())
int sa[maxn], rk[maxn], hi[maxn];
// tax / tp: scratch buffers for the radix sort in rsort() and da();
// mn: sparse table of range minima over hi[], built by init().
int tax[maxn], tp[maxn], mn[maxn][21];
// Node pool of the persistent segment tree: subtree count, left child, right child.
int sum[maxn * 40], ls[maxn * 40], rs[maxn * 40];
// rt[i]: root of the version holding sa[1..i]; tot: number of nodes allocated.
int rt[maxn], tot;
// n: string length, m: key bound for the radix sort, q: number of queries.
int n, m, q;
// Stable counting sort: orders the positions listed in tp[1..n] by their key
// rk[position] and writes the result into sa[1..n]. Keys must lie in [0, m].
void rsort(int n, int m){
    // Clear the bucket counters.
    int key = 0;
    while(key <= m) tax[key++] = 0;
    // Histogram of keys.
    for(int idx = 1; idx <= n; idx++) tax[rk[tp[idx]]]++;
    // Inclusive prefix sums: tax[v] becomes the last output slot for key v.
    for(int v = 1; v <= m; v++) tax[v] = tax[v] + tax[v - 1];
    // Walk tp backwards so equal keys keep their relative order (stability).
    for(int idx = n; idx > 0; idx--){
        const int pos = tp[idx];
        sa[tax[rk[pos]]--] = pos;
    }
}
// Returns 1 when the suffixes starting at x and y have the same first-half key
// (tp[x] == tp[y]) and the same second-half key (tp[x + len] == tp[y + len]).
// The second half is only read when both x + len and y + len stay within n;
// otherwise the suffixes are considered different and 0 is returned.
int cmp(int x, int y, int len, int n){
    if(tp[x] != tp[y]) return 0;
    if(x + len > n || y + len > n) return 0;
    return tp[x + len] == tp[y + len] ? 1 : 0;
}
// Builds the suffix array of s[1..n] by prefix doubling with radix sort, then
// the height array.
// On return: sa[i] = start of the i-th smallest suffix, rk[p] = rank of the
// suffix starting at p, hi[i] = LCP(suffix sa[i - 1], suffix sa[i]).
// m is the initial key bound: every character value s[i] must be <= m.
void da(int n, int m){
    // Round 0: each suffix is keyed by its first character.
    for(int i = 1; i <= n; ++i) tp[i] = i, rk[i] = s[i];
    rsort(n, m);
    // Doubling rounds. p counts distinct ranks; stop once all n ranks are distinct.
    // The next round's key bound m is the number of distinct ranks p.
    for(int len = 1, p = 0; p < n; len <<= 1, m = p){
        p = 0;
        // Order by second key: suffixes with no second half (start > n - len) first...
        for(int i = n - len + 1; i <= n; ++i) tp[++p] = i;
        // ...then the others, in sa order of their second half (which starts at sa[i]).
        for(int i = 1; i <= n; ++i) if(sa[i] > len) tp[++p] = sa[i] - len;
        // Stable sort by the first key rk[] yields the (first, second) key order.
        rsort(n, m);
        // Keep the previous ranks in tp so cmp() can compare both halves.
        for(int i = 1; i <= n; ++i) tp[i] = rk[i];
        // Re-rank: adjacent suffixes with equal (first, second) keys share a rank.
        p = rk[sa[1]] = 1;
        for(int i = 2; i <= n; ++i) rk[sa[i]] = cmp(sa[i - 1], sa[i], len, n) ? p : ++p;
    }
    // Kasai's algorithm: hi[rk[i]] >= hi[rk[i - 1]] - 1, so k only drops by one per step.
    int k = 0;
    for(int i = 1; i <= n; ++i){
        if(k) --k;
        // Predecessor in sorted order. For rank 1 this reads sa[0], which rsort never
        // writes (a global, so 0); assumes s[0] == '\0' (main reads into s + 1).
        int j = sa[rk[i] - 1];
        // Stops at the first mismatch; relies on s[n + 1] == '\0' as a terminator.
        while(s[i + k] == s[j + k]) ++k;
        hi[rk[i]] = k;
    }
}
void init(){
for(int i = 1; i <= n; ++i) mn[i][0] = hi[i];
int lim = log(n) / log(2);
for(int j = 1; j <= lim; ++j){
for(int i = 1; i <= n; ++i){
if(i + (1 << j) - 1 > n) break;
mn[i][j] = min(mn[i][j - 1], mn[i + (1 << (j - 1))][j - 1]);
}
}
}
// Returns the LCP of the suffixes whose ranks are l and r, given in either
// order. That is min(hi[a + 1 .. b]) with a = min(l, r) and b = max(l, r),
// answered in O(1) from the sparse table mn[][].
// The window's level is floor(log2(length)), computed with exact integer bit
// operations instead of floating-point log.
// For l == r the hi[] range is empty. That case returns INT_MAX, treated as an
// unbounded match, instead of evaluating log2(0).
int ask(int l, int r){
    if(l > r) swap(l, r);
    if(l == r) return INT_MAX;
    ++l;   // hi[x] relates rank x to rank x - 1, so the range starts one past l
    const int k = 31 - __builtin_clz(static_cast<unsigned>(r - l + 1));
    // Two overlapping power-of-two windows cover [l, r] exactly.
    return min(mn[l][k], mn[r - (1 << k) + 1][k]);
}
// Returns the smallest rank `answer` in [1, id] such that every suffix ranked
// in [answer, id] shares at least `len` leading characters with suffix sa[id].
// Binary search over ranks below id; the LCP with sa[id] never increases as
// the rank moves further away.
int getL(int id, int len){
    int answer = id;
    for(int lo = 1, up = id - 1; lo <= up; ){
        const int mid = (lo + up) >> 1;
        if(ask(mid, id) < len){
            lo = mid + 1;
        }else{
            answer = mid;
            up = mid - 1;
        }
    }
    return answer;
}
// Returns the largest rank `answer` in [id, n] such that every suffix ranked
// in [id, answer] shares at least `len` leading characters with suffix sa[id].
// Mirror image of getL(): binary search over ranks above id.
int getR(int id, int len){
    int answer = id;
    for(int lo = id + 1, up = n; lo <= up; ){
        const int mid = (lo + up) >> 1;
        if(ask(id, mid) < len){
            up = mid - 1;
        }else{
            answer = mid;
            lo = mid + 1;
        }
    }
    return answer;
}
// Persistent point update over the value range [l, r]. Creates a new version
// rooted at `rt` that equals version `pre` with `val` added at `pos`. Only the
// nodes on the root-to-leaf path are copied; all other nodes are shared.
void update(int l, int r, int &rt, int pre, int pos, int val){
    const int cur = ++tot;
    rt = cur;
    ls[cur] = ls[pre];
    rs[cur] = rs[pre];
    sum[cur] = sum[pre] + val;
    if(l < r){
        const int mid = (l + r) >> 1;
        if(pos > mid) update(mid + 1, r, rs[cur], rs[pre], pos, val);
        else update(l, mid, ls[cur], ls[pre], pos, val);
    }
}
// Finds the k-th smallest position counted in version `rt` minus version `pre`.
// With the versions built in main this means among sa[L..R].
// Returns -1 if the leaf reached does not hold exactly k remaining items, which
// happens when fewer than k positions exist (positions are distinct, so each
// leaf count is 0 or 1).
int query(int l, int r, int rt, int pre, int k){
    int cur = rt, old = pre;
    while(l != r){
        const int mid = (l + r) >> 1;
        const int leftCount = sum[ls[cur]] - sum[ls[old]];
        if(leftCount >= k){
            cur = ls[cur];
            old = ls[old];
            r = mid;
        }else{
            k -= leftCount;
            cur = rs[cur];
            old = rs[old];
            l = mid + 1;
        }
    }
    return sum[cur] - sum[old] == k ? l : -1;
}
int main() {
// ios::sync_with_stdio(0); cin.tie(0);
int t; scanf("%d", &t);
while(t--){
scanf("%d%d%s", &n, &q, s + 1);
m = 255, tot = 0;
da(n, m); init();
// for(int i = 1; i <= n; ++i) printf("%d %d %d %d\n", i, sa[i], rk[i], hi[i]);
for(int i = 1; i <= n; ++i){
update(1, n, rt[i], rt[i - 1], sa[i], 1);
}
while(q--){
int l, r, k; scanf("%d%d%d", &l, &r, &k);
int len = r - l + 1, id = rk[l];
int L = getL(id, len);
int R = getR(id, len);
int ret = query(1, n, rt[R], rt[L - 1], k);
printf("%d\n", ret);
}
}
return 0;
}
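另一种姿势是用后缀自动机:从前缀 s[1…r] 对应的结点出发,沿 parent 树倍增跳到满足 len >= r - l + 1 的最浅祖先,该结点即为 s[l…r] 所在的结点。它的 endpos 集合由其 parent 子树内所有前缀结点的结束位置组成,用线段树合并维护。在该集合中求第 k 小的结束位置 p,答案即为 p - (r - l + 1) + 1。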
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
// Contest-style shorthand macros.
#define pb push_back
#define sz(a) ((int)a.size())
#define mem(a, b) memset(a, b, sizeof a)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
// Midpoint of the current [l, r]; expands using the caller's local `l` and `r`.
#define gmid (l + r >> 1)
// Array bound, sized for the suffix-automaton states of a string of length <= 1e5
// (an automaton has fewer than 2n states).
const int maxn = 2e5 + 5;
// NOTE(review): maxm, mod and inf are not referenced in this program.
const int maxm = 2e6 + 5;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
// Parent-link (suffix-link) tree: G[p] lists the states whose par is p.
vector<int> G[maxn];
// Input string, read into s[1..n].
char s[maxn];
// nxt: transitions on 'a'..'z'; len: longest string length of a state; par: suffix link;
// id[x]: state created for the prefix s[1..x]; fa[u][i]: 2^i-th ancestor of u in the parent tree.
int nxt[maxn][26], len[maxn], par[maxn], id[maxn], fa[maxn][21];
// Node pool of the dynamic segment trees over end positions 1..n (count, children);
// rt[u]: root of state u's tree, which after dfs() covers u's whole subtree.
int sum[maxn * 40], ls[maxn * 40], rs[maxn * 40], rt[maxn];
// n: string length, m: query count, last: state of the current whole prefix,
// cnt: number of states, tot: number of tree nodes allocated.
int n, m, last, cnt, tot;
// Allocates a fresh automaton state with no outgoing transitions and longest
// length l, and returns its index.
inline int add(int l){
    ++cnt;
    memset(nxt[cnt], 0, sizeof nxt[cnt]);
    len[cnt] = l;
    return cnt;
}
// Resets the automaton to a single root state (index 1), which also becomes `last`.
void init(){
    cnt = 0;
    last = add(0);
}
// Appends character ch to the suffix automaton (online construction) and records
// in id[x] the state created for the prefix s[1..x].
void insert(char ch, int x){
    int t = ch - 'a', p = last, np;   // assumes ch is in 'a'..'z'
    // New state for the whole new prefix.
    last = np = add(len[p] + 1); id[x] = np;
    // Add transitions on t along the suffix-link path until one already exists.
    while(p && !nxt[p][t]) nxt[p][t] = np, p = par[p];
    // Walked past the root (par[1] is 0): np's suffix link is the root.
    if(!p) { par[np] = 1; return; }
    int q = nxt[p][t];
    // len[q] == len[p] + 1: q can serve directly as np's suffix link.
    if(len[p] + 1 == len[q]) { par[np] = q; return; }
    // Otherwise split q: clone nq with length len[p] + 1 and q's transitions.
    int nq = add(len[p] + 1);
    memcpy(nxt[nq], nxt[q], sizeof nxt[q]);
    par[nq] = par[q], par[np] = par[q] = nq;
    // Redirect transitions on t that pointed to q along the suffix-link path to the clone.
    while(p && nxt[p][t] == q) nxt[p][t] = nq, p = par[p];
}
// Adds `val` at position `pos` in the dynamic segment tree rooted at `rt` over
// [l, r]. Missing nodes are allocated on the way down, and each internal node's
// count is recomputed from its children on the way back up.
void update(int l, int r, int &rt, int pos, int val){
    if(rt == 0) rt = ++tot;
    if(l < r){
        const int mid = (l + r) >> 1;
        if(pos > mid) update(mid + 1, r, rs[rt], pos, val);
        else update(l, mid, ls[rt], pos, val);
        sum[rt] = sum[ls[rt]] + sum[rs[rt]];
        return;
    }
    sum[rt] += val;
}
int query(int l, int r, int rt, int k){
if(l == r) return l;
int mid = gmid, d = sum[ls[rt]];
if(d >= k) return query(l, mid, ls[rt], k);
else return query(mid + 1, r, rs[rt], k - d);
}
// Non-destructive merge over [l, r]: returns a tree whose counts are the sums
// of trees u and v. Wherever both are non-empty a fresh node is allocated, so u
// and v remain intact; their roots stay queryable for their own states. If one
// side is empty, the other subtree is shared as-is.
int merge(int l, int r, int u, int v){
    if(u == 0) return v;
    if(v == 0) return u;
    const int node = ++tot;
    if(l == r){
        sum[node] = sum[u] + sum[v];
        return node;
    }
    const int mid = (l + r) >> 1;
    ls[node] = merge(l, mid, ls[u], ls[v]);
    rs[node] = merge(mid + 1, r, rs[u], rs[v]);
    sum[node] = sum[ls[node]] + sum[rs[node]];
    return node;
}
// Processes the parent-link subtree rooted at u, where f is u's own parent
// (0 for the root). It does two things:
//   * fills the binary-lifting table: fa[x][i] is the 2^i-th ancestor of x;
//   * sets rt[x] to the merge of the end-position trees of x's whole subtree.
// The walk uses an explicit stack rather than recursion. The parent tree can be
// a chain of about n states (e.g. s = "aaa...a"), and recursion that deep risks
// overflowing the call stack.
void dfs(int u, int f){
    fa[u][0] = f;
    std::vector<int> order;          // pre-order: every state comes after its parent
    std::vector<int> stk(1, u);
    while(!stk.empty()){
        const int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        // All ancestors of x were popped earlier, so their tables are complete.
        for(int i = 1; i <= 20; ++i) fa[x][i] = fa[fa[x][i - 1]][i - 1];
        for(int v : G[x]){
            fa[v][0] = x;
            stk.push_back(v);
        }
    }
    // Reverse pre-order visits every descendant before its ancestor, so each
    // subtree's tree is complete before being merged into its parent.
    // order[0] == u is skipped: it has no parent inside this subtree.
    for(int i = static_cast<int>(order.size()) - 1; i > 0; --i){
        const int x = order[i];
        const int p = fa[x][0];
        rt[p] = merge(1, n, rt[p], rt[x]);
    }
}
// Binary lifting: climbs from state u to its highest ancestor whose longest
// length is still >= l. For l <= len[u], that is the state containing the
// length-l suffix of u's strings. len[0] == 0 and l >= 1, so the climb never
// passes the root into index 0.
int getPos(int u, int l){
    int step = 20;
    while(step >= 0){
        const int up = fa[u][step];
        if(len[up] >= l) u = up;
        --step;
    }
    return u;
}
int main() {
// ios::sync_with_stdio(0); cin.tie(0);
int t; scanf("%d", &t);
while(t--){
scanf("%d%d%s", &n, &m, s + 1);
for(int i = 1; i <= cnt; ++i) G[i].clear(), rt[i] = 0;
for(int i = 1; i <= tot; ++i) ls[i] = rs[i] = sum[i] = 0;
tot = 0, init();
for(int i = 1; i <= n; ++i) insert(s[i], i);
for(int i = 2; i <= cnt; ++i) G[par[i]].pb(i);
for(int i = 1; i <= n; ++i) update(1, n, rt[id[i]], i, 1);
dfs(1, 0);
while(m--){
int l, r, k; scanf("%d%d%d", &l, &r, &k); l = r - l + 1;
int u = getPos(id[r], l);
int ret = sum[rt[u]] >= k ? query(1, n, rt[u], k) : -1;
printf("%d\n", ret != -1 ? ret - l + 1 : -1);
}
}
return 0;
}