题目:Problem - 6988 (dingbacode.com)
大意是给出一个字符串和每个字符的权重,求所有本质不同的子字符串中第k小的字符串的权重。字符串的权重是字符权重之和。
注意到必须要顾及所有子串的情况,才能从中选出第k小的答案。所以考虑用数据结构存储或者快速枚举所有情况以便得出结果。已知所有后缀的所有前缀就是所有的子串情况,所以我们枚举每一个后缀的所有前缀就有望得出结果。但是暴力枚举绝对超时,此时就需要数据结构来帮助处理。
首先明确一点,从所有情况中选出第k小的答案这种做法会超时。因为情况非常多,且每种情况之间的大小关系不明确,要得到所有排好序的子串大小关系非常困难,而且费时。注意到题目只要输出结果这唯一一个数字,且对于每一个后缀,它的前缀权重和长度满足单调关系,所以尝试直接二分枚举答案,再验证其正确性。
下面考虑如何二分,先确定下界是1,上界是原串总重。每次计算中值mid,并判断答案在mid前或后。若答案为mid时,子串权重不大于mid的数量少于k,说明mid小了;否则说明答案<=mid,并进一步缩小范围。
然后就是如何得出子串权重不大于mid的数量。对于每一个后缀,前缀越长权重越大,所以可以再用一个二分找到第一个权重大于mid的前缀,假设当前后缀的起始位置在i,这个后缀的第一个权重大于mid的前缀的末尾在j,那么这个后缀对当前枚举的答案所贡献的前缀数量就是(j-1)-i+1。但是注意到不同后缀的前缀可能相同,所以要知道不同后缀的公共前缀(LCP)来对答案去重。到了这里可以用后缀数组来处理了,按照后缀排名顺序枚举后缀,就能用height数组得知当前串与前一串的重复前缀,直接减去即可。
至于后缀数组是什么,可以参考这个->后缀数组详解 - Chrety - 博客园 (cnblogs.com)
讲的很好,居然能把我给看会了。
代码:
#include<iostream>
using namespace std;
typedef long long int ll;
ll ran[100005], sec[100005], sa[100005], t[100005], hei[100005];
ll val[26], sum[100005];
char s[100005];
void getsa(ll n)
{
ll num, i, len, cnt;
num =26;
for (i = 0;i <= 100000;i++)t[i] = 0;
for (i = 1;i <= n;i++)ran[i] = s[i - 1]-'a'+1, t[ran[i]]++;
for (i = 1;i <= num;i++)t[i] += t[i - 1];
for (i = n;i >= 1;i--)sa[t[ran[i]]--] = i;
for (len = 1;len <= n;len <<= 1)
{
cnt = 0;
for (i = n - len + 1;i <= n;i++)sec[++cnt] = i;
for (i = 1;i <= n;i++)if (sa[i] > len)sec[++cnt] = sa[i] - len;
for (i = 0;i <= num;i++)t[i] = 0;
for (i = 1;i <= n;i++)t[ran[i]]++;
for (i = 1;i <= num;i++)t[i] += t[i - 1];
for (i = n;i >= 1;i--)sa[t[ran[sec[i]]]--] = sec[i], sec[i] = ran[i];
ran[sa[1]] = 1;cnt = 1;
for (i = 2;i <= n;i++)
ran[sa[i]] = (sec[sa[i]] == sec[sa[i - 1]] && sec[sa[i] + len] == sec[sa[i - 1] + len]) ? cnt : ++cnt;
if (cnt == n)break;
num = cnt;
}
}
void geth(ll n)
{
ll i, j, k;
k = 0;
for (i = 1;i <= n;i++)
{
if (k)k--;
j = sa[ran[i] - 1];
while (i + k <= n && j + k <= n && s[i - 1 + k] == s[j - 1 + k])k++;
hei[ran[i]] = k;
}
}
ll check(ll x, ll n)
{
ll ans = 0, i;
ll l, r, mid, add;
for (i = 1;i <= n;i++)
{
l = sa[i];r = n + 1;add = -1;
while (l < r)
{
mid = (l + r) >> 1;
if (sum[mid] - sum[sa[i] - 1] <= x)l = mid + 1, add = mid;
else r = mid;
}
if (add != -1)
{
ans += add - sa[i] + 1;
ans -= add - sa[i] + 1 > hei[i] ? hei[i] : add - sa[i] + 1;
}
}
return ans;
}
int main()
{
ll tt, n, i, k, ans, l, r, mid;
cin >> tt;
while (tt--)
{
scanf("%lld%lld", &n, &k);
scanf("%s", s);
for (i = 0;i < 26;i++)scanf("%lld", &val[i]);
getsa(n);
geth(n);
sum[0] = 0;
for (i = 1;i <= n;i++)sum[i] = val[s[i - 1] - 'a'] + sum[i - 1];
l = 1;r = sum[n] + 1;ans = -1;
while (l < r)
{
mid = (l + r) >> 1;
if (check(mid, n) >= k)r = mid, ans = mid;
else l = mid + 1;
}
printf("%lld\n", ans);
}
}