题意
给你一个字符串,$a \sim z$ 每个字母有一个权值,子串的权值为其中所有字母权值之和。找出权值第 $k$ 小的子串(的权值),若不存在输出 -1。
分析
第 $k$ 小,考虑二分答案,$check$ 函数计算有多少个子串的权值不超过当前二分的值。
如何计算权值不超过当前 $mid$ 的子串数量?
可以使用前缀和,对于当前起点 $l$,二分右端点 $r$,使得 $val_{l \sim r}$ 值刚好小于等于 $mid$(指 $l \sim r$ 这一段子串权值小于等于 $mid$)。
但按照上述方法二分会出现重复,如何去重?
后缀数组可以对一个字符串的后缀进行排序,并求出排名相邻的两个后缀的最长公共前缀,可以在计算时减去最长公共前缀 $lcp$。
想一想二分细节!
时间复杂度 $O(n\log^2 n)$
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int ,int> pii;
#define endl '\n'
// Greatest common divisor of a and b via the Euclidean algorithm.
// Returns a unchanged when b is 0.
ll gcd(ll a, ll b){
    while(b != 0){
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Local-debugging helper: rebinds stdin to "in.txt" and stdout to "out.txt".
// The results of freopen are not checked.
void input(){
    const char* const in_path = "in.txt";
    const char* const out_path = "out.txt";
    freopen(in_path, "r", stdin);
    freopen(out_path, "w", stdout);
}
const int N = 1e6+10, M = N * 2, inf = 1e8;
int n, m;   // n: length of the string in s[1..n]; m: current number of distinct sort keys
char s[N];  // input string, 1-indexed
// x: first-key rank of the suffix starting at each position,
// y: suffix start positions ordered by second key (after each round: the previous ranks),
// c: counting-sort buckets
int sa[N], x[N], y[N], c[N], rk[N], height[N];

/**
 * Builds the suffix array of s[1..n] into sa[1..n] by prefix doubling with
 * counting sort; sa[i] is the start position of the i-th smallest suffix.
 *
 * Reads the globals n and s, and uses x, y, c and m as scratch space.
 * Characters are used directly as initial keys, so they must lie in 1..122
 * (i.e. up to 'z').
 *
 * Safe to call repeatedly for successive test cases: ranks of positions past
 * the end of the string are treated as 0 instead of being read from the
 * arrays, so leftovers from an earlier, longer string cannot influence the
 * result, and no element beyond index n is touched.
 */
void get_sa(){
    m = 122;
    for(int i = 1; i <= m; i++) c[i] = 0; // reset buckets (multiple test cases)
    for(int i = 1; i <= n; i++) c[x[i] = s[i]]++;
    for(int i = 2; i <= m; i++) c[i] += c[i-1];
    for(int i = n; i; i--) sa[c[x[i]]--] = i;
    for(int k = 1; k <= n; k <<= 1){
        int num = 0;
        // Suffixes shorter than k+1 have an empty second half: they come first.
        for(int i = n-k+1; i <= n; i++) y[++num] = i;
        // The rest follow in the order of their second half, taken from sa.
        for(int i = 1; i <= n; i++)
            if(sa[i] > k) y[++num] = sa[i] - k;
        // Stable counting sort by first key, preserving the second-key order.
        for(int i = 1; i <= m; i++) c[i] = 0;
        for(int i = 1; i <= n; i++) c[x[i]]++;
        for(int i = 2; i <= m; i++) c[i] += c[i-1];
        for(int i = n; i; i--) sa[c[x[y[i]]]--] = y[i];
        // Move the old ranks into y. Only [1..n] is live, so only that range
        // is exchanged (swapping the whole arrays would cost O(N) per round).
        std::swap_ranges(x + 1, x + n + 1, y + 1);
        // Old rank at pos; positions past the end compare as the empty suffix.
        auto old_rank = [&](int pos){ return pos <= n ? y[pos] : 0; };
        x[sa[1]] = 1, num = 1;
        for(int i = 2; i <= n; i++){
            const bool same = y[sa[i]] == y[sa[i-1]]
                           && old_rank(sa[i] + k) == old_rank(sa[i-1] + k);
            x[sa[i]] = same ? num : ++num;
        }
        if(num == n) break; // all ranks distinct: order is final
        m = num;
    }
}
void get_height(){
for(int i = 1; i <= n; i++) rk[sa[i]] = i;
for(int i = 1, k = 0; i <= n; i++){
if(rk[i] == 1) continue;
if(k) k--;
int j = sa[rk[i]-1];
while(i + k <= n && j + k <= n && s[i+k] == s[j+k]) k++;
height[rk[i]] = k;
}
}
int t, v[N], sum[N];
ll k;
bool check(int now){
ll res = 0;
for(int i = 1; i <= n; i++){
int id = upper_bound(sum + 1, sum + n + 1, now + sum[i-1]) - sum;
int nowsum = id - i;
nowsum = max(0, nowsum-height[rk[i]]);
res += nowsum;
}
return res >= k;
}
// Per test case: reads n, k, the string and the 26 letter weights, builds the
// prefix sums, suffix array and height array, then binary-searches the smallest
// weight w such that at least k distinct substrings have weight <= w.
// Prints that weight, or -1 when fewer than k distinct substrings exist.
int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    // input();
    cin>>t;
    while(t--){
        cin>>n>>k>>(s+1);
        for(int i = 0; i <= 25; i++) cin>>v[i];
        for(int i = 1; i <= n; i++) sum[i] = sum[i-1] + v[s[i] - 'a'];
        get_sa();
        get_height();
        // No substring can weigh more than the whole string, so sum[n] + 1 is
        // an unreachable answer and doubles as the "not found" sentinel.
        // (check() adds the candidate to a prefix sum, so this presumes the
        // total weight stays well inside the int range.)
        const int not_found = sum[n] + 1;
        int l = 0, r = not_found; // binary search for the k-th smallest weight
        while(l < r){
            const int mid = l + (r - l) / 2;
            if(check(mid)) r = mid;
            else l = mid + 1;
        }
        if(r != not_found) cout<<r<<endl;
        else cout<<-1<<endl;
    }
    return 0;
}