1600 Simple KMP
题意非常难理解,读懂题,画图后发现要求的就是每次增加一个字符后当前所有后缀的匹配个数
关于证明可以看程序下面
由于可以直接跳prt树,但是是
O
(
n
2
)
O(n^2)
O(n2)的,由于数据太水可以混过去
如果要动态做就是要将每次Extend的时候都更新prt树上一整条链上的所有点,i节点加上mxl[i]-mxl[pre[i]]
这个动态操作可以用LCT做,但感觉太麻烦写不出来.
于是就做成了离线的,先构造出整个字符串的prt树,做树链剖分,线段树上记一个当前区间的计数器,
记录访问次数,把mxl[i]-mxl[pre[i]]的和记录上传,于是就可以打标记了.
第一次f[i]+=f[i-1]+query相当于求出当前加入字符后对答案的贡献
第二次f[i]+=f[i-1]相当于做前缀和,统计整个答案.
#include <bits/stdc++.h>
#define ll long long
#define int ll
#define enter putchar('\n')
#define space putchar(' ')
#define pb(x) push_back(x)
#define jh(x, y) (x ^= y, y ^= x, x ^= y)
#define mid ((l + r) >> 1)
#define ls (p << 1)
#define rs (p << 1 | 1)
using namespace std;
template <class T> inline void read(T &x) {
x = 0; T f = 1; char ch = getchar();
while (!(ch >= '0' && ch <= '9')) { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
x *= f;
}
template <class T> inline int stread(T *x) {
int len = 0; char ch = getchar();
while (!(ch >= 'a' && ch <= 'z')) ch = getchar();
while (ch >= 'a' && ch <= 'z') x[len++] = ch, ch = getchar();
x[len] = 0; return len;
}
template <class T> inline void write(T x) {
if (x < 0) x = -x, putchar('-');
if (x / 10) write(x / 10);
putchar('0' + x % 10);
}
template <class T> inline void Max(T &x, T y) { if (y > x) x = y; }
template <class T> inline void Min(T &x, T y) { if (y < x) x = y; }
const int MAX = 0x7fffffff;
const int MIN = 0x80000000;
const int INF = 0x3f3f3f3f;
//---------------------------------------------------------------------
const int N = 2e5 + 10;
const int MOD = 1e9 + 7;
int ch[N][26], pre[N], mxl[N], last = 1, cnt = 1;
int n, ans, pos[N], prt[N], dep[N], bel[N], h[N], sz[N], son[N], dfn[N], dtot, f[N];
char s[N];
struct Edge { int b, nt; } e[N];
int head[N], e_num;
struct Seg { int sum, tag, val; } T[N << 2];
inline void anode(int u, int v) {
e[++e_num].b = v; e[e_num].nt = head[u]; head[u] = e_num;
}
inline int Extend(int c) {
int p = last, np = ++cnt; last = np;
mxl[np] = mxl[p] + 1;
for (; p && !ch[p][c]; p = pre[p]) ch[p][c] = np;
if (!p) { pre[np] = 1; return np; }
int q = ch[p][c], nq = ++cnt;
if (mxl[q] == mxl[p] + 1) { pre[np] = q; --cnt; return np; }
memcpy(ch[nq], ch[q], sizeof(ch[q]));
mxl[nq] = mxl[p] + 1, pre[nq] = pre[q], pre[q] = pre[np] = nq;
for (; ch[p][c] == q; p = pre[p]) ch[p][c] = nq;
return np;
}
inline void dfs1(int u) {
sz[u] = 1;
for (int i = head[u]; i; i = e[i].nt) {
int v = e[i].b;
if (v == prt[u]) continue;
prt[v] = u, dep[v] = dep[u] + 1;
dfs1(v);
sz[u] += sz[v];
if (sz[v] > sz[son[u]]) son[u] = v;
}
}
inline void dfs2(int u, int chain) {
bel[u] = chain;
dfn[u] = ++dtot; h[dtot] = u;
if (son[u]) dfs2(son[u], chain);
for (int i = head[u]; i; i = e[i].nt) {
int v = e[i].b;
if (v == prt[u] || v == son[u]) continue;
dfs2(v, v);
}
}
inline void pushup(int p) { T[p].sum = (T[ls].sum + T[rs].sum) % MOD; }
inline void build(int p, int l = 1, int r = cnt) {
if (l == r) {
T[p].val = mxl[h[l]] - mxl[pre[h[l]]];
return;
}
build(ls, l, mid), build(rs, mid + 1, r);
T[p].val = (T[ls].val + T[rs].val) % MOD;
}
inline void pushdown(int p) {
if (T[p].tag > 0) {
int t = T[p].tag % MOD; T[p].tag = 0;
(T[ls].sum += T[ls].val * t) %= MOD;
(T[rs].sum += T[rs].val * t) %= MOD;
(T[ls].tag += t) %= MOD;
(T[rs].tag += t) %= MOD;
}
}
inline int query(int p, int L, int R, int l = 1, int r = cnt) {
if (L == l && R == r) { return T[p].sum % MOD; }
pushdown(p);
int ret = 0;
if (R <= mid) ret = query(ls, L, R, l, mid);
else if (L > mid) ret = query(rs, L, R, mid + 1, r);
else ret = query(ls, L, mid, l, mid) + query(rs, mid + 1, R, mid + 1, r);
pushup(p); return ret % MOD;
}
inline int query_chain(int x, int y) {
int ret = 0;
while (bel[x] != bel[y]) {
if (dep[bel[x]] < dep[bel[y]]) jh(x, y);
(ret += query(1, dfn[bel[x]], dfn[x])) %= MOD;
x = prt[bel[x]];
}
if (dep[x] < dep[y]) jh(x, y);
(ret += query(1, dfn[y], dfn[x])) %= MOD;
return ret;
}
inline void update(int p, int L, int R, int l = 1, int r = cnt) {
if (L == l && R == r) {
T[p].sum += T[p].val;
++T[p].tag;
return;
}
pushdown(p);
if (R <= mid) update(ls, L, R, l, mid);
else if (L > mid) update(rs, L, R, mid + 1, r);
else update(ls, L, mid, l, mid), update(rs, mid + 1, R, mid + 1, r);
pushup(p);
}
inline void update_chain(int x, int y) {
while (bel[x] != bel[y]) {
if (dep[bel[x]] < dep[bel[y]]) jh(x, y);
update(1, dfn[bel[x]], dfn[x]);
x = prt[bel[x]];
}
if (dep[x] < dep[y]) jh(x, y);
update(1, dfn[y], dfn[x]);
}
signed main() {
read(n); scanf("%s", s + 1);
for (int i = 1; i <= n; i++) pos[i] = Extend(s[i] - 'a');
for (int i = 1; i <= cnt; i++) anode(pre[i], i);
dfs1(1), dfs2(1, 1); build(1);
for (int i = 1; i <= n; i++) {
f[i] = (f[i - 1] + query_chain(1, pos[i])) % MOD;
update_chain(1, pos[i]);
}
for (int i = 1; i <= n; i++) {
(f[i] += f[i - 1]) %= MOD;
write(f[i]), enter;
}
return 0;
}
题解证明
我们先分析一个字符串的fail树的深度之和的性质
可以发现,一个串x的父亲是最大的前缀满足这个前缀等于后缀,那么我们可以推测:一个串x的祖先数量是x的前缀等于后缀的数量
那么答案就是
∑
i
=
1
n
∑
j
=
i
n
∑
k
=
1
j
−
i
+
1
[
S
[
i
.
.
.
i
+
k
−
1
]
=
=
S
[
j
−
k
+
1...
j
]
]
\sum_{i=1}^n \sum_{j=i}^n \sum_{k=1}^{j-i+1}[S[i...i+k-1]==S[j-k+1...j]]
∑i=1n∑j=in∑k=1j−i+1[S[i...i+k−1]==S[j−k+1...j]]
我们直接考虑两个子串S[x…y],S[l…r]的贡献,其中l>x,S[x…y]==S[l…r]
显然贡献就是有多少个子串以S[x…r]为后缀,贡献就是n-r+1
我们可以求出f[i]表示串S[1…i]有几对相等的子串,然后答案就是
∑
i
=
1
n
f
[
i
]
\sum_{i=1}^nf[i]
∑i=1nf[i]
那么问题来了,怎么求呢
如果只是要求有几对相等的子串的话,可以建出后缀树,对于一个点x,他的贡献是
C
S
(
x
)
2
∗
l
e
n
(
x
)
C_{S(x)}^2*len(x)
CS(x)2∗len(x),其中S(x)是x的子树中的后缀数量,len(x)是他到父亲的边的长度
我们只要用后缀自动机维护后缀树,再用Link Cut Tree 维护后缀树的答案即可
具体就是,我们要实现将x的父亲换为y,可以求出S(x),然后相当于链上加一个数d,询问链上的平方和,可以用二项式展开维护下
于是问题就完美解决了
时间复杂度O(nlogn)