洛谷传送门
BZOJ传送门
题目描述
给定一个长度为 n n n 的字符串 S S S,令 T i T_i Ti 表示它从第 i i i 个字符开始的后缀。求
∑ 1 ⩽ i < j ⩽ n len ( T i ) + len ( T j ) − 2 × lcp ( T i , T j ) \displaystyle \sum_{1\leqslant i<j\leqslant n}\text{len}(T_i)+\text{len}(T_j)-2\times\text{lcp}(T_i,T_j) 1⩽i<j⩽n∑len(Ti)+len(Tj)−2×lcp(Ti,Tj)
其中, len ( a ) \text{len}(a) len(a) 表示字符串 a a a 的长度, lcp ( a , b ) \text{lcp}(a,b) lcp(a,b) 表示字符串 a a a 和字符串 b b b 的最长公共前缀。
输入输出格式
输入格式:
一行,一个字符串 S S S。
输出格式:
一行,一个整数,表示所求值。
输入输出样例
输入样例#1:
cacao
输出样例#1:
54
说明
对于 100% 的数据,保证 2 ⩽ n ⩽ 500000 2\leqslant n\leqslant 500000 2⩽n⩽500000,且均为小写字母。
解题分析
题目求的是后缀的公共前缀, 我们把串倒过来, 就变成了求前缀的最长公共后缀。
在 S A M SAM SAM里每个点都包含了前缀, 而显然其公共后缀就是在 p a r e n t parent parent树上面的LCA。
所以我们直接在 p a r e n t parent parent树上从下往上 D P DP DP就好了。
至于这个玩意
∑
1
⩽
i
<
j
⩽
n
len
(
T
i
)
+
len
(
T
j
)
\displaystyle \sum_{1\leqslant i<j\leqslant n}\text{len}(T_i)+\text{len}(T_j)
1⩽i<j⩽n∑len(Ti)+len(Tj)
显然就等于
(
l
e
n
−
1
)
l
e
n
(
l
e
n
+
1
)
2
\frac{(len-1)len(len+1)}{2}
2(len−1)len(len+1)。
代码如下:
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <cctype>
#define R register
#define IN inline
#define W while
#define gc getchar()
#define MX 1000050
#define ll long long
int l, cnt, cur, last;
ll ans, num[MX];
int to[MX][26], len[MX], par[MX], siz[MX], buc[MX], ord[MX];
char dat[MX];
namespace SAM
{
IN void insert(R int id)
{
R int now, tar;
cur = ++cnt; len[cur] = len[last] + 1; siz[cur] = 1;
for (now = last; ~now; now = par[now])
{
if(to[now][id]) break;
to[now][id] = cur;
}
if(now < 0) return last = cur, par[cur] = 0, void();
tar = to[now][id];
if(len[tar] == len[now] + 1) return last = cur, par[cur] = tar, void();
int nw = ++cnt; std::memcpy(to[nw], to[tar], sizeof(to[tar]));
par[nw] = par[tar], par[tar] = nw; len[nw] = len[now] + 1;
for (; (~now) && to[now][id] == tar; now = par[now]) to[now][id] = nw;
par[cur] = nw; last = cur;
}
IN ll calc()
{
R int now; ll ret = 0;
for (R int i = 1; i <= cnt; ++i) ++buc[len[i]];
for (R int i = 1; i <= l; ++i) buc[i] += buc[i - 1];
for (R int i = 1; i <= cnt; ++i) ord[buc[len[i]]--] = i;
for (R int i = cnt; i; --i)
{
now = ord[i];
if(par[now] > 0)
ret += 1ll * siz[now] * siz[par[now]] * len[par[now]], siz[par[now]] += siz[now];
}
return ret * 2;
}
}
int main(void)
{
scanf("%s", dat + 1); l = std::strlen(dat + 1);
ans = 1ll * (l - 1) * l / 2 * (l + 1); std::reverse(dat + 1, dat + 1 + l);
par[0] = -1;
for (R int i = 1; i <= l; ++i) SAM::insert(dat[i] - 'a');
printf("%lld", ans - SAM::calc());
}