【思路要点】
- 注意到字符串的生成方式近乎随机,不妨认为:对于一个后缀 $S[i\dots N]$,它与其余后缀的 $LCP$ 的最大值 $k_i$ 不会很大。事实上,当字符串随机时,$k_i$ 的期望为 $O(\log N)$ 级别。
- 用后缀树求出 $k_i$,那么 $S[i\dots i],S[i\dots i+1],\dots,S[i\dots i+k_i-1]$ 在串中出现了至少两次,而 $S[i\dots N]$ 的其余前缀在串中仅出现一次,不会作为 $border$ 出现。
- 枚举所有可能的 $border$。假设当前枚举到 $S$,其最长的 $border$ 为 $S'$,$S$ 在串中的出现次数为 $cnt$。$S$ 的任意两次出现(较早的一次作为前缀、较晚的一次作为后缀)唯一确定一个以 $S$ 为 $border$ 的子串。一个子串最长 $border$ 的长度可以沿 $border$ 链拆成若干项 $|S|-|S'|$ 之和,因此 $S$ 对答案的贡献为 $\binom{cnt}{2}\cdot(|S|-|S'|)$。
- 用字典树实现,时间复杂度为 $O(\sum k_i)=O(N\log N)$。
- 但事实上字符串并不一定是完全随机的。
考虑数列 $x_0=seed,\ x_i=13331x_{i-1}+23333\ (i\ge 1)$。取 $c$ 满足 $13330c\equiv 23333 \pmod{10^9+7}$,则有通项公式 $x_i=13331^i(seed+c)-c\ (i\ge 1)$。注意到 $13331$ 是 $10^9+7$ 的原根,因此:当 $seed\ne -c$ 时,数列的循环节长度为 $10^9+6$,字符串可以认为是随机的;而当 $seed=-c$ 时,数列的循环节长度为 $1$,整个字符串都是同一个字符,需要特判一下。
- $-c\equiv 791372847 \pmod{10^9+7}$。
【代码】
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>

// Sums, over every substring occurrence of the generated 0/1 string a[1..n],
// the length of its longest proper border.
//
// Main path: a suffix tree yields k[i] (max LCP of suffix i with any other
// suffix). Only S[i..i+j-1] with j <= k[i] can be borders. Every repeated
// substring P is stored in a binary trie with its occurrence count. Each
// ordered pair of occurrences contributes |P| - |longest border of P|.
//
// Fallback path (solve): if sum(k) is huge, the generator is assumed to be
// eventually periodic with a short period (asserted), and KMP is used per
// distinct suffix shape.

constexpr int MAXN = 1000000 + 5;   // max string length + slack
constexpr int MAXP = 40000000 + 5;  // capacity of the binary trie in main()

// Reads a (possibly negative) decimal integer from stdin.
// `ch` is an int so EOF is detectable and isdigit never sees a negative char;
// both loops stop at EOF instead of spinning forever.
template <typename T>
void read(T &x) {
    x = 0;
    int f = 1;
    int ch = std::getchar();
    for (; ch != EOF && !std::isdigit(ch); ch = std::getchar())
        if (ch == '-') f = -f;
    for (; ch != EOF && std::isdigit(ch); ch = std::getchar())
        x = x * 10 + ch - '0';
    x *= f;
}

// Suffix automaton built over the string read right-to-left (extend() is fed
// s[len], s[len-1], ..., s[1]); its fail (suffix-link) tree is used as the
// suffix tree of the original string.
struct SuffixTree {
    int root, size, last;
    int child[MAXN * 2][2], home[MAXN * 2];
    int fail[MAXN * 2], depth[MAXN * 2], tsize[MAXN * 2];

    // Allocates a fresh node whose longest string has length `dep`.
    // home is cleared so cloned nodes are never mistaken for suffix nodes.
    int newnode(int dep) {
        fail[size] = 0;
        depth[size] = dep;
        home[size] = 0;
        std::memset(child[size], 0, sizeof(child[size]));
        return size++;
    }

    // Standard SAM extension by character ch (0 or 1). `from` is the start
    // index of the original suffix represented by the new node.
    void extend(int ch, int from) {
        int p = last, np = newnode(depth[last] + 1);
        // root is node 0 and fail[root] == 0 == root, so this walk stops at
        // root at the latest; child[p][ch] == np afterwards means it did.
        while (child[p][ch] == 0) {
            child[p][ch] = np;
            p = fail[p];
        }
        if (child[p][ch] == np) {
            fail[np] = root;
        } else {
            int q = child[p][ch];
            if (depth[q] == depth[p] + 1) {
                fail[np] = q;
            } else {
                // Split q: the clone nq takes the shorter strings of q.
                int nq = newnode(depth[p] + 1);
                fail[nq] = fail[q];
                fail[q] = fail[np] = nq;
                std::memcpy(child[nq], child[q], sizeof(child[q]));
                while (child[p][ch] == q) {
                    child[p][ch] = nq;
                    p = fail[p];
                }
            }
        }
        home[last = np] = from;
    }

    // Builds the structure for s[1..len] and writes, for each i, res[i] =
    // longest common prefix of suffix i with any other suffix.
    void init(int *s, int len, int *res) {
        size = 0;
        root = last = newnode(0);
        for (int i = len; i >= 1; i--) extend(s[i], i);
        std::fill(tsize, tsize + size, 0);
        for (int i = 1; i < size; i++) tsize[fail[i]]++;
        for (int i = 1; i < size; i++) {
            if (!home[i]) continue;  // cloned node, not a suffix
            // A suffix node that still has children in the fail tree is a
            // prefix of a longer suffix, so its best LCP is its full length;
            // otherwise it is the depth of its parent.
            res[home[i]] = tsize[i] ? depth[i] : depth[fail[i]];
        }
    }
} ST;

int n, seed, a[MAXN], nxt[MAXN], k[MAXN];
// Binary trie of repeated substrings (node 0 is the root). Prefixed with
// `trie` to keep them distinct from SuffixTree's members and from std::size.
int trieRoot, trieSize, trieCnt[MAXP], trieChild[MAXP][2];

// KMP failure function of s[1..n] (1-indexed): nxt[i] is the length of the
// longest proper border of s[1..i]. Writes nxt[1] even when n == 0.
void calcnxt(int *s, int n) {
    nxt[1] = 0;
    for (int i = 2; i <= n; i++) {
        int j = nxt[i - 1];
        while (j != 0 && s[j + 1] != s[i]) j = nxt[j];
        nxt[i] = (s[j + 1] == s[i]) ? j + 1 : j;
    }
}

// Fallback for inputs whose generator sequence becomes periodic quickly.
// One example is seed == -c (mod 1e9+7), where every character is equal.
// Re-runs the generator from `seed` to find the pre-period length and period,
// then sums nxt[] for every substring, reusing one KMP pass per residue class
// of starting positions inside the periodic part.
void solve(int n, int seed) {
    std::map<int, int> firstSeen;  // generator value -> index of first sight
    int from = 1, period = n;
    for (int i = 1; i <= n; i++) {
        seed = (1ll * seed * 13331 + 23333) % 1000000007;
        auto it = firstSeen.find(seed);
        if (it != firstSeen.end()) {
            from = it->second;
            period = i - it->second;
            break;
        }
        firstSeen[seed] = i;
    }
    assert(period <= 100);
    long long ans = 0;
    // Pre-periodic starting positions: handled one by one.
    for (int i = 1; i < from; i++) {
        calcnxt(a + i - 1, n - i + 1);
        for (int j = 1; j <= n - i + 1; j++) ans += nxt[j];
    }
    // Periodic part: suffixes whose starts differ by a multiple of `period`
    // are prefixes of one another, so one KMP pass plus prefix sums serves
    // the whole residue class.
    n = n - from + 1;
    static long long sum[MAXN];
    for (int i = 1; i <= period; i++) {
        calcnxt(a + from + i - 2, n - i + 1);
        for (int j = 1; j <= n - i + 1; j++) sum[j] = sum[j - 1] + nxt[j];
        for (int j = n - i + 1; j >= 1; j -= period) ans += sum[j];
    }
    std::cout << ans << '\n';
}

int main() {
    read(n), read(seed);
    int bak = seed;
    for (int i = 1; i <= n; i++) {
        seed = (1ll * seed * 13331 + 23333) % 1000000007;
        a[i] = seed & 1;
    }
    ST.init(a, n, k);
    long long sum = 0;
    for (int i = 1; i <= n; i++) sum += k[i];
    if (sum > 1e8) {
        solve(n, bak);
        return 0;
    }

    // Insert every repeated substring S[i..i+j-1] (j <= k[i]) into the trie,
    // counting its occurrences.
    for (int i = 1; i <= n; i++) {
        int now = trieRoot;
        for (int j = 1, pos = i; j <= k[i] && pos <= n; j++, pos++) {
            int &slot = trieChild[now][a[pos]];
            if (slot == 0) {
                // Node count is bounded only by sum(k) (<= 1e8), not by MAXP.
                assert(trieSize + 1 < MAXP && "trie capacity exceeded");
                slot = ++trieSize;
            }
            now = slot;
            trieCnt[now]++;
        }
    }

    // For each start i (in increasing order), remove occurrence i first, so
    // trieCnt counts only occurrences starting after i. Each ordered pair of
    // occurrences of P is therefore counted once, weighted by
    // |P| - |longest border of P|.
    long long ans = 0;
    for (int i = 1; i <= n; i++) {
        int now = trieRoot;
        int tlen = std::min(k[i], n - i + 1);
        calcnxt(a + i - 1, tlen);
        for (int j = 1, pos = i; j <= k[i] && pos <= n; j++, pos++) {
            now = trieChild[now][a[pos]];
            trieCnt[now]--;
            ans += 1ll * trieCnt[now] * (j - nxt[j]);
        }
    }
    std::cout << ans << '\n';
    return 0;
}