Analysis & Solution
f(S) f ( S ) 其实就是 S S 所有前缀出现次数减一的和,因为这里定义的深度 等于其非空border的数量,就意味着有一个子串与一个前缀相等,交换求和次序即前缀在非前缀子串中出现次数的和。(我TM连第一步转化都没想对,我转化成了 ∑r(|S|−r+1)∑lhl[r] ∑ r ( | S | − r + 1 ) ∑ l h l [ r ] )
这样 key(S) k e y ( S ) 就是考虑 S S 的一个子串出现的right集合 {r1,r2,…,rp} { r 1 , r 2 , … , r p } ,其中一对位置 (ri,rj) ( r i , r j ) 的贡献是 |S|−rj+1 | S | − r j + 1 。
现在考虑算出在末尾加一个字符带来的增量。对于right集合不变的结点,是对所有 (p2) ( p 2 ) 种方案长度都增加1(这里 p p 表示right当前集合的大小)。这用一个全局变量 记下 ∑(p2) ∑ ( p 2 ) 即可。如果right集合变大了(位于末点到根的链上),可先另 cur+=∑p++ c u r += ∑ p ++ 维护出 p p 增加后的 ,再使 ans+=cur a n s += c u r 。相当于先再结尾加一个点,其方案对应长度均为0,再令所有方案增加1的长度。当然,这里讨论的贡献还得乘上 len[i]−len[par[i]] l e n [ i ] − l e n [ p a r [ i ] ]
由于不需要在线,可以先把整个SAM建出来然后链剖,避免LCT。
Implementation
这里我们需要一个资辞区间加、区间和的线段树,用树状数组可以简便的实现。但是现在要带上 len−len[par] l e n − l e n [ p a r ] 的权值,需对模板作些修改。
ll A[N]; int B[N];
template <typename T>
void add(T (&tr)[N], int x, int w) {
do tr[x] += w;
while ((x += x&-x) <= tim);
}
template <typename T>
T sum(T (&tr)[N], int x) {
T s = 0;
do s += tr[x]; while (x ^= x&-x);
return s;
}
void update(int l, int r) {
add(A, l, l);
add(B, l, 1);
add(A, r+1, -r-1);
add(B, r+1, -1);
}
ll query(int l, int r) {
return (r+1)*sum(B, r) - sum(A, r)
- (l*sum(B, l-1) - sum(A, l-1));
}
原来是考虑区间加拆成一个在i处+1以后,对区间和拆成的一个j处的前缀和的贡献,等于
(j−i+1)[i≤j]
(
j
−
i
+
1
)
[
i
≤
j
]
。把这个贡献拆成
−i
−
i
和
j+1
j
+
1
两部分,分别用A
和B
维护。
现在贡献变成
(s[j]−s[i−1])[i≤j]
(
s
[
j
]
−
s
[
i
−
1
]
)
[
i
≤
j
]
,其中
s
s
是 的前缀和。也可以拆成两部分,只要把r+1
等地方改一下就可以了。
但要注意,应该把
len
l
e
n
的差分按dfs序重排后再前缀和。我一开始认为链剖处理的区间不会跨过树上不连续的链,所以可以直接传树上
len
l
e
n
进去。事实证明死得很惨,大概是因为在l-1
,r+1
的时候就会跑出去。
p.s. 我又把SAM写错了。。。少了一句 par[nq] = par[q];
Code
#include <vector>
#include <cstdio>
#include <cstring>
typedef long long ll;
const int S = 100032, N = S*2, MOD = 1000000007;
char s[S];
int go[N][26], par[N], len[N], cnt = 1, p = 1;
void extend(int c) {
int np = ++cnt;
len[np] = len[p] + 1;
for (; p && !go[p][c]; p = par[p]) go[p][c] = np;
if (p) {
int q = go[p][c];
if (len[q] != len[p] + 1) {
int nq = ++cnt;
len[nq] = len[p] + 1;
par[nq] = par[q];
memcpy(go[nq], go[q], sizeof *go);
for (; p && go[p][c] == q; p = par[p]) go[p][c] = nq;
par[q] = par[np] = nq;
}
else par[np] = q;
}
else par[np] = 1;
p = np;
}
std::vector<int> ch[N];
int son[N], top[N], dfn[N], tim;
ll _len[N];
int dfs1(int x) {
int sz = 1, mx = 0;
for (int u : ch[x]) {
int s = dfs1(u); sz += s;
if (mx < s) mx = s, son[x] = u;
}
return sz;
}
void dfs2(int x) {
dfn[x] = ++tim;
_len[tim] = len[x] - len[par[x]];
if (int v = son[x]) {
top[v] = top[x]; dfs2(v);
for (int u : ch[x])
if (u ^ v) dfs2(top[u] = u);
}
}
int B[N]; ll A[N], ans, cur;
template <typename T>
inline void add(T (&tr)[N], int x, T w) {
do tr[x] += w;
while ((x += x&-x) <= tim);
}
template <typename T>
inline T sum(T (&tr)[N], int x) {
T s = 0;
do s += tr[x]; while (x ^= x&-x);
return s;
}
void query(int l, int r) {
ll wl = _len[l-1], wr = _len[r];
cur += wr * sum(B, r) - sum(A, r)
- wl * sum(B, l-1) + sum(A, l-1);
add(A, l, wl); add(B, l, 1);
add(A, r+1, -wr); add(B, r+1, -1);
}
int main() {
int n;
scanf("%d%s",&n,s);
for (int i = 0; i < n; i++)
extend(s[i] -= 'a');
for (int i = 2; i <= cnt; i++)
ch[par[i]].push_back(i);
dfs1(1);
dfs2(top[1] = 1);
for (int i = 0; i < tim; i++)
_len[i+1] += _len[i];
for (int x = 1, i = 0; i < n;) {
int u = x = go[x][s[i++]];
do query(dfn[top[u]], dfn[u]);
while (u = par[top[u]]);
if ((ans += cur %= MOD) >= MOD) ans -= MOD;
printf("%lld\n",ans);
}
return 0;
}