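/*
 * Solution
 * 反串的后缀自动机就是原串的后缀树，在后缀树上树形 DP 即可。
 * 其中 $\mathrm{max}_u$（即代码中的 mx[u]）是点 $u$ 在后缀树上的深度。
 * 记 $T_i$ 为从第 $i$ 个字符开始的后缀，则答案为
 * $$\frac{(n-1)n(n+1)}{2} - 2\sum_{1 \le i < j \le n} \mathrm{lcp}(T_i, T_j).$$
 *
 * (English) The suffix automaton of the reversed string is the suffix tree of
 * the original string, so a tree DP on it suffices. $\mathrm{max}_u$ (mx[u] in
 * the code) is the depth of node $u$ in that suffix tree. With $T_i$ the
 * suffix starting at position $i$, the answer is
 * $\frac{(n-1)n(n+1)}{2} - 2\sum_{1 \le i < j \le n} \mathrm{lcp}(T_i, T_j)$.
 */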
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for the answer.
typedef long long ll;
// Capacity of every per-state array in SAM and of the input buffer `s`.
// A suffix automaton over n characters has at most 2n states, so this
// comfortably covers inputs of roughly half a million characters.
const int N = 1010101;
// Buffered single-byte reader over stdin.
// Returns the next byte as a value in 0..255, or EOF once input is exhausted.
// The return type is int (not char) so that EOF stays distinguishable from a
// real 0xFF byte and compares correctly regardless of char signedness.
inline int get(void) {
    static char buf[100000], *S = buf, *T = buf;
    if (S == T) {
        T = (S = buf) + fread(buf, 1, sizeof buf, stdin);
        if (S == T) return EOF;
    }
    return static_cast<unsigned char>(*S++);
}

// Reads a decimal integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. If EOF is reached before any digit, x is left as 0
// (previously this looped forever).
template<typename T>
inline void read(T &x) {
    int c = get();
    bool negative = false;
    x = 0;
    for (; c != EOF && (c < '0' || c > '9'); c = get())
        if (c == '-') negative = true;
    for (; c >= '0' && c <= '9'; c = get())
        x = x * 10 + (c - '0');  // parenthesised: avoids transient overflow near the type's max
    if (negative) x = -x;
}

// Reads the next maximal run of lowercase letters 'a'..'z' from stdin into
// ch and NUL-terminates it; any other characters before the word are skipped.
// If EOF is reached before any letter, ch becomes the empty string (previously
// this looped forever).
// NOTE: no bound is checked - ch must have room for the word plus the NUL.
inline void read(char *ch) {
    int c = get();
    int len = 0;
    while (c != EOF && (c < 'a' || c > 'z')) c = get();
    for (; c >= 'a' && c <= 'z'; c = get()) ch[len++] = static_cast<char>(c);
    ch[len] = 0;
}
// Length of the input string.
int n = 0;
// Input string; main() reverses it in place before building the automaton.
char s[N];
// Answer accumulator: main() seeds it with the total of all pair lengths and
// SAM::dp() subtracts 2 * lcp for every pair.
ll ans = 0;
// Suffix automaton over the alphabet 'a'..'z' (characters mapped to 0..25).
// main() builds it over the reversed input, so its suffix-link tree (par[])
// is the suffix tree of the original string; dp() runs a tree DP over it.
// State 1 is the root; index 0 acts as a "no state" sentinel (par[root] is
// never assigned and stays 0). The constructor does not clear the arrays, so
// the object relies on static zero-initialization (it is a global).
struct SAM {
// root: initial state; last: state for the whole string inserted so far;
// Tcnt: number of states allocated (valid states are 1..Tcnt).
int root, last, Tcnt;
// mx[u]: length of the longest string in state u (its depth in the
// suffix-link tree); par[u]: suffix link of u.
int mx[N], par[N];
// ri[u]: 1 for states created as np in Extend (one per inserted prefix),
// 0 for clones; after dp()'s first pass it is the number of such states in
// u's suffix-link subtree.
// f[u]: during dp(), how many of those states have already been merged
// into u (starts at 1 for np states, which count themselves).
int ri[N], f[N];
// buc: counting-sort buckets keyed by mx; id[1..Tcnt]: states in
// non-decreasing mx order after Sort().
int buc[N], id[N];
// to[u][c]: transition from u on character c (0 = none). Insert() and
// Debug() only use indices 0..25.
// NOTE(review): 26 columns would suffice; the 27th is never touched.
int to[N][27];
// Character stack used by Debug() to print the current path.
int sta[N], top;
inline SAM(void) { root = last = Tcnt = 1; }
// Appends character c (0..25) using the standard online construction and
// returns the new state np, which represents the whole string so far.
inline int Extend(int c) {
int p = last, np = ++Tcnt;
// np stands for one new prefix: it counts once in both ri and f.
ri[np] = f[np] = 1;
mx[np] = mx[p] + 1;
// Walk suffix links, adding a transition on c to np wherever one is missing.
for (; p && !to[p][c]; p = par[p])
to[p][c] = np;
if (p) {
int q = to[p][c];
if (mx[q] != mx[p] + 1) {
// q also holds strings longer than mx[p] + 1: split it by cloning into
// nq. The clone's ri/f stay 0 (fresh zero-initialized slot) because it
// does not correspond to a new prefix.
int nq = ++Tcnt;
mx[nq] = mx[p] + 1;
memcpy(to[nq], to[q], sizeof to[q]);
par[nq] = par[q];
par[q] = par[np] = nq;
// Redirect the remaining transitions on c that pointed at q to the clone.
for (; p && to[p][c] == q; p = par[p])
to[p][c] = nq;
} else {
par[np] = q;
}
} else {
// No state on the suffix path had a transition on c: link to the root.
par[np] = root;
}
return last = np;
}
// Extends the automaton with every character in [begin, end); characters
// are expected to be lowercase letters ('a' maps to 0).
inline void Insert(char *begin, char *end) {
for (char *c = begin; c != end; c++)
Extend(*c - 'a');
}
// Counting sort of states 1..Tcnt by mx into id[], ascending. Because a
// suffix link always has smaller mx, walking id[] backwards visits every
// state before its parent. buc[] is never cleared, so call this only once.
inline void Sort(void) {
for (int i = 1; i <= Tcnt; i++) ++buc[mx[i]];
for (int i = 1; i <= Tcnt; i++) buc[i] += buc[i - 1];
for (int i = Tcnt; i; i--) id[buc[mx[i]]--] = i;
}
// Tree DP over the suffix-link tree: subtracts 2 * lcp from the global `ans`
// for every unordered pair of prefix states (i.e. pair of original suffixes).
inline void dp(void) {
// Pass 1 (children before parents): accumulate subtree counts into ri.
// The root's contribution lands in the sentinel ri[0], which is unused.
for (int i = Tcnt; i; i--)
ri[par[id[i]]] += ri[id[i]];
for (int i = Tcnt; i; i--) {
int x = id[i];
// Each of the ri[x] states below x pairs with each of the f[par[x]]
// states already merged into par[x]; par[x] is their lowest common
// ancestor, so their common length is mx[par[x]]. For the root,
// par[x] is the sentinel 0 and mx[0] == 0, so nothing is subtracted.
// The 2ll operand forces the product into 64-bit arithmetic.
ans -= 2ll * mx[par[x]] * ri[x] * f[par[x]];
f[par[x]] += ri[x];
}
}
// Debug helper: recursively prints, one per line, the label of every path
// starting at state u (beginning with the empty path). From the root this
// lists every distinct substring, so it is only usable on tiny inputs.
inline void Debug(int u = 1) {
for (int i = 1; i <= top; i++) putchar(sta[i] + 'a');
putchar('\n');
for (int i = 0; i < 26; i++)
if (to[u][i]) {
sta[++top] = i;
Debug(to[u][i]);
--top;
}
}
};
// Global automaton instance. SAM's constructor only sets root/last/Tcnt, so
// the transition, count and bucket arrays depend on the zero-initialization
// that static storage provides. At roughly a hundred megabytes, the object
// is also far too large for the stack.
SAM S;
int main(void) {
freopen("1.in", "r", stdin);
freopen("1.out", "w", stdout);
read(s); n = strlen(s);
reverse(s, s + n);
S.Insert(s, s + n);
ans = (ll)(n - 1) * n * (n + 1) / 2;
S.Sort(); S.dp();// S.Debug();
cout << ans << endl;
return 0;
}