网上的题解都是后缀自动机……这里讲一下后缀数组的做法。
求出后缀数组的
height
数组之后,任意两个不相同的后缀的lcp,都是
height
数组中一段区间的最小值。
在给出的式子中,
len(Ti)+len(Tj)
以及系数
2
可以提到外面去。所以现在要求出的就是
把求lcp转化成求
height
的区间最小值:
∑2≤i≤j≤nminjk=iheight[k]
。
把求区间最小值之和换一个定义,改为求
height
中的每个位置是多少个区间的最小值。为了避免重复,这里的最小值是指数值为第一关键字,位置为第二关键字。
预处理出两个数组:
1、
pre[i]
:满足
j<i
且
height[j]≤height[i]
的最大的
j
,如果没有则为
2、
suf[i]
:满足
j>i
且
height[j]<height[i]
的最小的
j
,如果没有则为
这两个数组可以用一个单调的栈扫描求得。
这样就可以得出,
height
的第
i
个位置是
32∑ni=2i(i−1)−2∑ni=2height[i](i−pre[i])(suf[i]−i)
。
代码:
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 5e5 + 5;
int n, sa[N], rank[N], height[N], w[N], stk[N], top, pre[N], suf[N];
char s[N];
void buildSA() {
int i, k, m = 26; int *x = rank, *y = height;
for (i = 1; i <= n; i++) w[x[i] = s[i] - 'a' + 1]++;
for (i = 2; i <= m; i++) w[i] += w[i - 1];
for (i = 1; i <= n; i++) sa[w[x[i]]--] = i;
for (k = 1; k < n; k <<= 1, swap(x, y)) {
int tt = 0; for (i = n - k + 1; i <= n; i++) y[++tt] = i;
for (i = 1; i <= n; i++) if (sa[i] > k) y[++tt] = sa[i] - k;
memset(w, 0, sizeof(w));
for (i = 1; i <= n; i++) w[x[i]]++;
for (i = 2; i <= m; i++) w[i] += w[i - 1];
for (i = n; i; i--) sa[w[x[y[i]]]--] = y[i];
m = 0; for (i = 1; i <= n; i++) {
int u = sa[i], v = sa[i - 1];
y[u] = x[u] != x[v] || x[u + k] != x[v + k] ? ++m : m;
}
if (m == n) break;
}
if (y != rank) copy(y, y + n + 1, rank);
height[1] = 0; for (i = 1, k = 0; i <= n; i++) {
if (k) k--; int u = sa[rank[i] - 1];
while (s[i + k] == s[u + k]) k++;
height[rank[i]] = k;
}
}
int main() {
int i; scanf("%s", s + 1); n = strlen(s + 1); buildSA();
stk[top = 0] = 1; for (i = 2; i <= n; i++) {
while (top && height[stk[top]] > height[i]) top--;
pre[stk[++top] = i] = stk[top - 1];
}
stk[top = 0] = n + 1; for (i = n; i >= 2; i--) {
while (top && height[stk[top]] >= height[i]) top--;
suf[stk[++top] = i] = stk[top - 1];
}
ll ans = 0; for (i = 2; i <= n; i++)
ans += 1ll * height[i] * (i - pre[i]) * (suf[i] - i);
ll tmp = 0; for (i = 2; i <= n; i++)
tmp += 3ll * i * (i - 1) >> 1;
cout << tmp - ans * 2 << endl;
return 0;
}