题意:
给定一个字符串 s s s,有 q q q 次询问,每次询问 l , r , t l, r, t l,r,t,询问 s [ l ⋯ r ] s[l \cdots r] s[l⋯r] 所有子串中比 t t t 字典序大的那些串里面的字典序最小的一个。 ( ∣ s ∣ ≤ 1 0 5 , q , ∑ ∣ t i ∣ ≤ 2 × 1 0 5 ) (|s| \leq 10^5, q,~\sum |t_i| \leq 2×10^5) (∣s∣≤105,q, ∑∣ti∣≤2×105)
链接:
https://codeforces.com/problemset/problem/1037/H
解题思路:
先考虑 l = 1 , r = n l = 1, r = n l=1,r=n 的情况,再扩展。涉及所有子串的查询,那么先建立 S A M SAM SAM,最终的答案串一定形如 t 1 t 2 ⋯ t p + c t_1t_2\cdots t_p + c t1t2⋯tp+c,其中 p p p 极大,且 c > t p + 1 c \gt t_{p + 1} c>tp+1, c c c 极小。那么让 t t t 在自动机上跑,每次检查最小可能的 c c c,并更新答案,那么子问题变成判断 t [ 1 ⋯ p ] t[1\cdots p] t[1⋯p] 拼接上 c c c 后是否在自动机上,即对应结点是否存在。
加上区间限制,那么就是判断 t [ 1 ⋯ p ] + c t[1 \cdots p] + c t[1⋯p]+c 是否在区间 [ l , r ] [l, r] [l,r] 中,线段树合并维护每个结点的 r i g h t right right 集合,然后再判断即可。(主席树维护也行)
参考代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
const int maxn = 2e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
char s[maxn];
int nxt[maxn][26], par[maxn], len[maxn], tax[maxn], rk[maxn];
int sum[maxn * 40], ls[maxn * 40], rs[maxn * 40], rt[maxn];
int n, q, cnt, last, tot;
int add(int l){
++cnt; len[cnt] = l; return cnt;
}
void init(){
cnt = 0; last = add(0);
}
void insert(char ch){
int t = ch - 'a', p = last, cur;
last = cur = add(len[p] + 1);
while(p && !nxt[p][t]) nxt[p][t] = cur, p = par[p];
if(!p) { par[cur] = 1; return; }
int q = nxt[p][t];
if(len[q] == len[p] + 1) { par[cur] = q; return; }
int nq = add(len[p] + 1);
memcpy(nxt[nq], nxt[q], sizeof nxt[q]);
par[nq] = par[q], par[q] = par[cur] = nq;
while(p && nxt[p][t] == q) nxt[p][t] = nq, p = par[p];
}
void update(int l, int r, int &rt, int pos, int val){
if(!rt) rt = ++tot;
if(l == r){
sum[rt] += val;
return;
}
int mid = gmid;
if(pos <= mid) update(l, mid, ls[rt], pos, val);
else update(mid + 1, r, rs[rt], pos, val);
sum[rt] = sum[ls[rt]] + sum[rs[rt]];
}
int merge(int l, int r, int x, int y){
if(!x || !y) return x + y;
int t = ++tot;
if(l == r){
sum[t] = sum[x] + sum[y];
return t;
}
int mid = gmid;
ls[t] = merge(l, mid, ls[x], ls[y]);
rs[t] = merge(mid + 1, r, rs[x], rs[y]);
sum[t] = sum[ls[t]] + sum[rs[t]];
return t;
}
int dfs(int l, int r, int rt){
if(l == r) return l;
int mid = gmid;
if(sum[rs[rt]]) return dfs(mid + 1, r, rs[rt]);
else return dfs(l, mid, ls[rt]);
}
int query(int l, int r, int rt, int L, int R){
if(!sum[rt]) return 0;
if(l >= L && r <= R){
if(!sum[rt]) return 0;
return dfs(l, r, rt);
}
int mid = gmid, ret;
if(R > mid) if(ret = query(mid + 1, r, rs[rt], L, R)) return ret;
if(L <= mid) if(ret = query(l, mid, ls[rt], L, R)) return ret;
return 0;
}
void rsort(){
for(int i = 1; i <= cnt; ++i) ++tax[len[i]];
for(int i = 1; i <= cnt; ++i) tax[i] += tax[i - 1];
for(int i = cnt; i >= 1; --i) rk[tax[len[i]]--] = i;
}
int main(){
ios::sync_with_stdio(0); cin.tie(0);
cin >> s + 1;
n = strlen(s + 1);
init();
for(int i = 1; i <= n; ++i){
insert(s[i]);
update(1, n, rt[last], i, 1);
}
rsort();
for(int i = cnt; i >= 2; --i){
int u = rk[i];
rt[par[u]] = merge(1, n, rt[par[u]], rt[u]);
}
cin >> q;
while(q--){
int l, r; cin >> l >> r >> s + 1;
int ret = -1, ch = -1;
int p = 1, m = strlen(s + 1);
s[++m] = 'a' - 1;
for(int i = 1; i <= m && p; ++i){
int t = s[i] - 'a';
for(int j = t + 1; j < 26; ++j){
int v = nxt[p][j];
if(!v) continue;
int R = query(1, n, rt[v], 1, r);
int L = R - i + 1;
if(L < l) continue;
ret = i - 1, ch = j;
break;
}
p = nxt[p][t];
}
if(ret == -1) cout << "-1\n";
else{
s[++ret] = ch + 'a', s[++ret] = 0;
cout << s + 1 << "\n";
}
}
return 0;
}