题目描述:
给出字符串$S$,求长度小于$k$的子串中,有多少对相等。
解题思路:
其实就是求所有后缀的两两公共前缀与$k$取$min$后的和。我们先用$SA$构出$Height$数组,然后用上升的单调栈维护。如果当前的$Height$小于栈顶,就把栈顶的“清算”,清算多出的那部分,和能够沿伸的长度计算一下。要注意一些细节。
代码:
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 using namespace std; 5 6 const int N = 1e5 + 10, mo = 998244353; 7 int n, k, m, ans, sa[N], height[N], rank[N], tax[N], tp[N]; 8 char s[N]; 9 10 void rsort() { 11 for (int i = 1; i <= m; i ++) tax[i] = 0; 12 for (int i = 1; i <= n; i ++) tax[rank[tp[i]]] ++; 13 for (int i = 1; i <= m; i ++) tax[i] += tax[i - 1]; 14 for (int i = n; i >= 1; i --) sa[tax[rank[tp[i]]] --] = tp[i]; 15 } 16 17 int cmp(int *f, int x, int y, int w) {return f[x] == f[y] && f[x + w] == f[y + w];} 18 19 void SA() { 20 for (int i = 1; i <= n; i ++) rank[i] = s[i] - 'a' + 1, tp[i] = i; 21 m = 26, rsort(); 22 for (int w = 1, p = 1, i; p < n; w <<= 1, m = p) { 23 for (p = 0, i = n - w + 1; i <= n; i ++) tp[++ p] = i; 24 for (int i = 1; i <= n; i ++) if (sa[i] > w) tp[++ p] = sa[i] - w; 25 rsort(), swap(rank, tp), rank[sa[1]] = p = 1; 26 for (int i = 2; i <= n; i ++) rank[sa[i]] = cmp(tp, sa[i], sa[i - 1], w) ? p : ++ p; 27 } 28 int j, k = 0; 29 for (int i = 1; i <= n; height[rank[i ++]] = k) 30 for (k = k ? k - 1 : k, j = sa[rank[i] - 1]; s[i + k] == s[j + k]; k ++); 31 } 32 33 void work() { 34 int sta[N][2], top = 0, x; 35 sta[0][1] = 1; 36 for (int i = 2; i <= n + 1; i ++) { 37 height[x = i] = min(height[i], k); 38 while (top && height[i] < sta[top][0]) { 39 int l = i - sta[top][1] + 1, t = sta[top][0] - max(height[i], sta[top - 1][0]); 40 (ans += 1ll * t * (1ll * l * (l - 1) / 2 % mo) % mo) %= mo; 41 x = sta[top --][1]; 42 } 43 if (height[i] > sta[top][0]) sta[++ top][0] = height[i], sta[top][1] = x; 44 } 45 } 46 47 int main() { 48 scanf("%d %d", &n, &k); 49 scanf("%s", s + 1); 50 SA(); 51 work(); 52 printf("%d", ans); 53 return 0; 54 }