题面
解法
重新开始学习后缀数组的第一道题……
- 显然,我们只要考虑怎么求 ∑ i ∑ j l c p ( s u f ( i ) , s u f ( j ) ) \sum_i\sum_j lcp(suf(i),suf(j)) ∑i∑jlcp(suf(i),suf(j))。
- 可以将问题转化成求排完序之后两两的 l c p lcp lcp之和。
- 因为我们知道,排完序之后 l c p ( s u f ( s a [ l ] ) , s u f ( s a [ r ] ) ) = m i n ( h t [ l ] , … , h t [ r ] ) lcp(suf(sa[l]),suf(sa[r]))=min(ht[l],\dots,ht[r]) lcp(suf(sa[l]),suf(sa[r]))=min(ht[l],…,ht[r])。那么我们可以对于每一个 h t [ i ] ht[i] ht[i]分别考虑它对答案产生的贡献是什么。求出 h t [ i ] ht[i] ht[i]作为最小值的时候所有可能的区间数就可以了,这个可以直接用单调栈来实现,这个部分的复杂度为 O ( n ) O(n) O(n)。
- 时间复杂度: O ( n log n ) O(n\log n) O(nlogn)
【注意事项】
- 因为 h t [ i ] ht[i] ht[i]作为最小值的所有区间中可能会存在有一个值等于 h t [ i ] ht[i] ht[i]的情况,这种情况我们会重复计算。那么我们可以强制使左端点都 ≥ h t [ i ] ≥ht[i] ≥ht[i],右端点都 > h t [ i ] >ht[i] >ht[i],这样就能避免出现这个问题了
代码
#include <bits/stdc++.h>
#define N 500010
using namespace std;
char st[N];
struct SuffixArray {
int n, m, l[N], r[N], s[N], y[N], ht[N], sa[N], rnk[N];
void Sort() {
for (int i = 1; i <= m; i++) s[i] = 0;
for (int i = 1; i <= n; i++) s[rnk[i]]++;
for (int i = 1; i <= m; i++) s[i] += s[i - 1];
for (int i = n; i; i--) sa[s[rnk[y[i]]]--] = y[i];
}
void build() {
n = strlen(st + 1), m = 130;
for (int i = 1; i <= n; i++) rnk[i] = st[i], y[i] = i;
Sort(); int len = 0;
for (int k = 1; k <= n; k <<= 1, m = len, len = 0) {
for (int i = n - k + 1; i <= n; i++) y[++len] = i;
for (int i = 1; i <= n; i++) if (sa[i] > k) y[++len] = sa[i] - k;
Sort(), swap(rnk, y), rnk[sa[1]] = len = 1;
for (int i = 2; i <= n; i++)
rnk[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? len : ++len;
if (len >= n) break;
}
for (int i = 1, k = 0; i <= n; i++) {
if (rnk[i] == 1) continue;
if (k) k--; int j = sa[rnk[i] - 1];
while (st[i + k] == st[j + k]) k++;
ht[rnk[i]] = k;
}
}
long long solve() {
stack <int> st;
for (int i = 1; i <= n; i++) {
while (!st.empty() && ht[i] <= ht[st.top()]) st.pop();
if (st.empty()) l[i] = 1; else l[i] = st.top() + 1;
st.push(i);
}
while (!st.empty()) st.pop();
for (int i = n; i; i--) {
while (!st.empty() && ht[i] < ht[st.top()]) st.pop();
if (st.empty()) r[i] = n; else r[i] = st.top() - 1;
st.push(i);
}
for (int i = 1; i <= n; i++) printf("[%d, %d] : ht[%d] = %d\n", l[i], r[i], i, ht[i]);
long long ret = 0;
for (int i = 1; i <= n; i++) ret += 2ll * (i - l[i] + 1) * (r[i] - i + 1) * ht[i];
return ret;
}
} SA;
int main() {
cin >> st + 1; SA.build();
int n = strlen(st + 1);
long long ans = 1ll * (n - 1) * n * (n + 1) / 2 - SA.solve();
cout << ans << "\n";
return 0;
}