题目链接
题意
问一个字符串有多少个子串出现恰好 k 次
思路
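求出后缀数组 $sa$ 和高度数组 $height$，其中 $height_i$ 为排名 $i-1$ 与排名 $i$ 的两个后缀的最长公共前缀长度。对排名连续的 $k$ 个后缀 $i, i+1, \dots, j$（$j = i+k-1$），令 $m = \min_{i < t \le j} height_t$，$M = \max(height_i, height_{j+1})$，则这 $k$ 个后缀长度在 $(M, m]$ 内的公共前缀恰好出现 $k$ 次，贡献为 $\max(0, m - M)$，对所有 $i$ 累加即为答案。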
注意 $k=1$ 时窗口内只有一个后缀，不存在区间最小值，需要特判：贡献为该后缀本身的长度减去它与前一个、后一个后缀的最长公共前缀中的较大值，即 $len_i - \max(height_i, height_{i+1})$。
Code
#include <bits/stdc++.h>
// Capacity of every global array: presumably the maximum input length plus
// the trailing 0 sentinel plus slack -- confirm against the problem limits.
const int maxn = 100010;
using namespace std;
typedef long long LL;
// s: the current input string.
char s[maxn];
// a:      s converted to ints, followed by the 0 sentinel (built in work()).
// wa, wb: scratch rank buffers used by init().
// sw:     counting-sort buckets; sv: first keys gathered in second-key order.
// sa:     suffix array; rk: its inverse (rk[sa[i]] == i).
// h[i]:   LCP of suffixes sa[i-1] and sa[i]; len[i]: length of suffix sa[i]
//         without the sentinel.
// mn:     sparse table over h; 26 levels exceed log2(maxn) comfortably.
// n:      length of a, sentinel included.
int a[maxn], wa[maxn], wb[maxn], sw[maxn], sv[maxn], h[maxn], rk[maxn], len[maxn], sa[maxn], mn[maxn][26], n;
// Returns true when r[i] == r[j] and r[i + l] == r[j + l].
// The second pair is read only when the first pair matches, so a caller may
// rely on that ordering to avoid touching r past its meaningful range.
bool cmp(int* r, int i, int j, int l) {
    if (r[i] != r[j]) {
        return false;
    }
    const int shiftedI = i + l;
    const int shiftedJ = j + l;
    return r[shiftedI] == r[shiftedJ];
}
// Builds the suffix array of r[0..n-1] by prefix doubling with counting
// (radix) sort, then fills the global arrays rk, h and len.
//   r  : input sequence; every value must lie in [0, m) (used as a bucket
//        index into sw). Assumes r[n-1] is a sentinel that is unique and
//        smaller than every other value (the caller in this file appends 0);
//        both cmp() and the LCP loop below rely on it to stop in range.
//   sa : output; sa[i] = start position of the i-th smallest suffix.
//   n  : length of r, sentinel included.
//   m  : bucket count for the first counting sort.
// On return, for 1 <= i < n (under the sentinel assumption):
//   rk[sa[i]] == i,
//   h[i]   = LCP of suffixes sa[i-1] and sa[i],
//   len[i] = n - 1 - sa[i], the suffix length not counting the sentinel,
// and h[n] = len[n] = 0 so callers may safely read one slot past rank n-1.
// Uses the globals wa, wb, sw, sv as scratch space.
void init(int* r, int* sa, int n, int m) {
// x: current rank of each position; y: positions ordered by second key.
// Both alias the global scratch buffers and are swapped every round.
int* x = wa, *y = wb, *t, i, j, k, p;
// Initial counting sort by single value.
for (i = 0; i < m; ++i) sw[i] = 0;
for (i = 0; i < n; ++i) ++sw[x[i] = r[i]];
for (i = 1; i < m; ++i) sw[i] += sw[i-1];
for (i = n-1; i >= 0; --i) sa[--sw[x[i]]] = i;
// Doubling rounds: after the round with step j, suffixes are ordered by their
// first 2j values. p counts distinct ranks; stop once all n are distinct.
// The next round needs only p buckets, hence m = p.
for (j = 1, p = 1; p < n; j <<= 1, m = p) {
// Second key, part 1: positions whose second half (i + j) runs off the end
// have the smallest second key and come first.
for (i = n-j, p = 0; i < n; ++i) y[p++] = i;
// Second key, part 2: the rest, in the order given by the previous sa.
for (i = 0; i < n; ++i) if (sa[i] >= j) y[p++] = sa[i] - j;
// Stable counting sort of y by first key (current rank x).
for (i = 0; i < n; ++i) sv[i] = x[y[i]];
for (i = 0; i < m; ++i) sw[i] = 0;
for (i = 0; i < n; ++i) ++sw[sv[i]];
for (i = 1; i < m; ++i) sw[i] += sw[i-1];
for (i = n-1; i >= 0; --i) sa[--sw[sv[i]]] = y[i];
// y now holds the old ranks. Re-rank: neighbours in sa with equal
// (rank, rank-at-offset-j) pairs share a rank.
t = x, x = y, y = t, x[sa[0]] = 0;
for (i = 1, p = 1; i < n; ++i) x[sa[i]] = cmp(y, sa[i], sa[i-1], j) ? p-1 : p++;
}
// Inverse permutation.
for (i = 0; i < n; ++i) rk[sa[i]] = i;
// Height array in text order (Kasai): moving from suffix i to suffix i+1 the
// LCP with the predecessor drops by at most 1, so k restarts at k-1.
// i = n-1 is skipped: under the sentinel assumption that suffix has rank 0
// and no predecessor.
k = 0;
for (i = 0; i < n-1; h[rk[i++]] = k) {
for (k = k ? k-1 : 0, j = sa[rk[i]-1]; r[i+k] == r[j+k]; ++k);
}
for (i = 1; i < n; ++i) len[i] = n - 1 - sa[i];
// Padding slot so h[i+1] / len[i+1] can be read for the last rank.
h[n] = len[n] = 0;
}
// Builds a sparse table for range-minimum queries over h[1..n]:
//   mn[i][j] = min(h[i], ..., h[i + 2^j - 1]).
// Only windows lying entirely inside [1, n] are filled; those are exactly
// the entries query() reads.
void rmqInit() {
    for (int i = 1; i <= n; ++i) mn[i][0] = h[i];
    for (int j = 1; (1 << j) <= n; ++j) {
        const int half = 1 << (j - 1);
        // Stop once the 2^j window would extend past h[n]; going further
        // would combine with mn[n + 1][j - 1], which is never initialised
        // for the current input (it may hold data from an earlier test case).
        for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
            mn[i][j] = min(mn[i][j - 1], mn[i + half][j - 1]);
        }
    }
}
// Returns min(h[l], ..., h[r]) from the sparse table built by rmqInit().
// Callers must pass 1 <= l <= r <= n.
int query(int l, int r) {
    // k = floor(log2(r - l + 1)), computed exactly in integers. A
    // floating-point log ratio is exposed to rounding; one level too small
    // leaves a gap between the two halves, and one too large reads past r.
    int k = 0;
    while ((2 << k) <= r - l + 1) ++k;
    // Two overlapping 2^k windows cover [l, r].
    return min(mn[l][k], mn[r - (1 << k) + 1][k]);
}
void work() {
int k;
scanf("%d%s", &k, s);
n = strlen(s); int m = 0;
for (int i = 0; i < n; ++i) a[i] = s[i], m = max(m, a[i]);
a[n++] = 0;
init(a, sa, n, ++m);
LL ans = 0;
if (k == 1) {
for (int i = 1; i < n; ++i) ans += (LL)(len[i] - max(h[i], h[i+1]));
}
else {
rmqInit();
for (int i = 1; i <= n-k; ++i) {
int j = i+k-1;
int minn = query(i+1, j), maxx = max(h[i], h[j+1]);
if (minn > maxx) ans += (LL)minn - maxx;
}
}
printf("%lld\n", ans);
}
// Entry point: reads the number of test cases T and solves each with work().
int main() {
#ifdef LOCAL
    // Local debugging only (compile with -DLOCAL). If left unconditional,
    // freopen closes stdin first and then fails when in.txt is missing, so
    // nothing is read on a judge.
    freopen("in.txt", "r", stdin);
#endif
    int T;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) work();
    return 0;
}