Solution
又看了半天的SAM,感觉这次看陈老师的讲稿流畅多了。
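这道题的话,对 $A$ 串建后缀自动机,然后把 $B$ 串放到上面匹配:设 $B$ 的当前前缀匹配到了状态 $u$,当前匹配长度为 $len$。
再考虑匹配到这个点的话,必然在 parent 树上它的祖先节点 $v$ 也是匹配的,又因为这些点满足 $\max_v \le \min_u \le len$,所以 $|right_v| \cdot (\max_v - \min_v + 1)$ 是能够全部计入方案的;而 $u$ 本身只有长度在 $(\max_{par_u}, len]$ 内的串被匹配到,贡献为 $|right_u| \cdot (len - \max_{par_u})$。
然后就是树形DP了:注意到 $\max_v - \min_v + 1 = \max_v - \max_{par_v}$,自顶向下预处理 $g_v = g_{par_v} + |right_v| \cdot (\max_v - \max_{par_v})$,则 $B$ 的每个位置的贡献就是 $g_{par_u} + |right_u| \cdot (len - \max_{par_u})$。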
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;  // at-least-64-bit type used for the match counts

// Capacity of every per-state array in the automaton and of the input buffers.
// NOTE(review): presumably sized for strings of up to ~2e5 characters, since a
// suffix automaton over n characters has at most 2n - 1 states; confirm limits.
const int N = 402020;
// Buffered replacement for getchar(): refills a static buffer from stdin with
// fread and hands out one byte per call.
// Returns the next byte as a value in [0, 255], or EOF once stdin is exhausted.
// The return type is int (like getchar) so that EOF stays distinct from the
// byte 0xFF, and can be detected at all where plain char is unsigned.
inline int get(void) {
    static char buf[100000], *S = buf, *T = buf;
    if (S == T) {
        T = (S = buf) + fread(buf, 1, sizeof buf, stdin);
        if (S == T) return EOF;
    }
    return static_cast<unsigned char>(*S++);
}
// Reads the next decimal integer from the buffered input into x.
// Characters before the first digit are skipped; if any of them is '-', the
// result is negated (same rule as before). If input ends before a digit is
// found, x is left as 0 instead of looping forever on EOF.
template<typename T>
inline void read(T &x) {
    x = 0;
    bool negative = false;
    int c = get();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') negative = true;
        c = get();
    }
    // Parenthesised so x * 10 + c cannot overflow before '0' is subtracted.
    for (; c >= '0' && c <= '9'; c = get()) x = x * 10 + (c - '0');
    if (negative) x = -x;
}
// Reads the next maximal run of lowercase letters ('a'..'z') into ch and
// NUL-terminates it. Leading non-lowercase characters are skipped. If input
// ends before any letter is seen, ch becomes the empty string (previously the
// skip loop never terminated on EOF).
// NOTE(review): there is no bounds check; ch must have room for the whole word
// plus the terminating NUL.
inline void read(char *ch) {
    int len = 0;
    int c = get();
    while (c != EOF && (c < 'a' || c > 'z')) c = get();
    while (c >= 'a' && c <= 'z') {
        ch[len++] = static_cast<char>(c);
        c = get();
    }
    ch[len] = 0;
}
// Input strings: the automaton is built on `a`; `b` is matched against it
// (see main). Each holds at most N - 1 characters plus the terminating NUL.
char a[N];
char b[N];
// Suffix automaton (SAM) built on one string, plus the parent-tree DP used to
// count, for every prefix of a second string, how many (end position in the
// indexed string, length) pairs match a suffix of that prefix.
//
// Conventions visible in the code:
//  - state 1 is the root (mx == 0); state 0 is a null sentinel: par[root] == 0
//    and the arrays at index 0 are never given real data;
//  - characters are passed as 0-based codes (Insert/Calc use *c - 'a');
//  - every array is sized N, so the number of states (Tcnt) must stay below N.
// NOTE(review): the constructor does not clear the arrays; correctness relies
// on the object having static storage duration (zero-initialised), as the
// global `S` does.
struct SAM {
    int mx[N], par[N];         // mx[v]: longest length in state v; par[v]: suffix link
    int ri[N], id[N], buc[N];  // ri[v]: |right(v)| after dp(); id: states by mx; buc: sort buckets
    ll g[N], f[N];             // g[v]: see dp(); NOTE(review): f is never used
    int to[N][27];             // transitions; Insert/Calc on 'a'..'z' only use codes 0..25
    int last, Tcnt, root;      // state of the whole indexed string, state count, root id
    int sta[N];                // path stack used by Debug()
    int top;

    SAM(void) { Tcnt = root = last = 1; }

    // Appends character code c to the indexed string (online SAM construction)
    // and returns the new `last` state.
    inline int Extend(int c) {
        int p = last, np = ++Tcnt;
        ri[np] = 1;  // np represents the new prefix, i.e. exactly one end position
        mx[np] = mx[p] + 1;
        for (; p && !to[p][c]; p = par[p])
            to[p][c] = np;
        if (p) {
            int q = to[p][c];
            if (mx[q] != mx[p] + 1) {
                // Split q: the clone nq keeps q's transitions and takes the
                // shorter strings. Clones keep ri == 0 (no end position of their own).
                int nq = ++Tcnt;
                mx[nq] = mx[p] + 1;
                memcpy(to[nq], to[q], sizeof to[q]);
                par[nq] = par[q];
                par[q] = par[np] = nq;
                for (; p && to[p][c] == q; p = par[p])
                    to[p][c] = nq;
            } else {
                par[np] = q;
            }
        } else {
            par[np] = root;
        }
        return last = np;
    }

    // Counting sort of states 1..Tcnt by mx ascending into id[1..Tcnt].
    // Because a suffix link always points to a state with smaller mx, this is a
    // top-down order of the parent tree. buc is not cleared, so call at most once.
    inline void Sort(void) {
        for (int i = 1; i <= Tcnt; i++) ++buc[mx[i]];
        for (int i = 1; i <= Tcnt; i++) buc[i] += buc[i - 1];
        for (int i = Tcnt; i; i--) id[buc[mx[i]]--] = i;
    }

    // Requires Sort(). First pass (children before parents) accumulates ri so
    // that ri[v] = |right(v)|. The root also pushes its total into ri[0], which
    // is never read. Second pass (parents before children):
    //   g[v] = g[par[v]] + |right(v)| * (mx[v] - mx[par[v]]),
    // i.e. the sum over v and all its ancestors of |right| times the number of
    // lengths each state represents. g[root] == 0 because mx[root] == mx[0] == 0.
    inline void dp(void) {
        for (int i = Tcnt; i; i--) ri[par[id[i]]] += ri[id[i]];
        for (int i = 1; i <= Tcnt; i++) {
            int pos = id[i];
            g[pos] = g[par[pos]] + (ll)ri[pos] * (mx[pos] - mx[par[pos]]);
        }
    }

    // Appends every character of [begin, end) (expected to be 'a'..'z').
    // Returns the state of the whole string inserted so far (`last`).
    // The explicit return matters: flowing off the end of a non-void function
    // is undefined behaviour in C++.
    inline int Insert(char *begin, char *end) {
        for (char *c = begin; c != end; c++)
            Extend(*c - 'a');
        return last;
    }

    // Requires Sort() and dp(). Walks [begin, end) through the automaton,
    // keeping (p, len) = state and length of the longest suffix of the text read
    // so far that occurs in the indexed string. For each position it adds
    // g[par[p]] (all ancestors of p, every length they represent) plus
    // |right(p)| * (len - mx[par[p]]) (only the matched lengths of p itself).
    inline ll Calc(char *begin, char *end) {
        ll ans = 0; int p = root, len = 0;
        for (char *c = begin; c != end; c++) {
            int x = *c - 'a';
            if (to[p][x]) {
                ++len; p = to[p][x];
            } else {
                // Fall back along suffix links until some state can read x.
                while (p && !to[p][x]) p = par[p];
                if (p) {
                    len = mx[p] + 1; p = to[p][x];
                } else {
                    p = root; len = 0;
                }
            }
            if (p != root) ans += g[par[p]] + (ll)(len - mx[par[p]]) * ri[p];
        }
        return ans;
    }

    // Debug helper: prints the label of every transition path starting at u,
    // one per line (from the root: every distinct substring, the empty one
    // first). Recursive; output size grows quickly, so use only on tiny inputs.
    inline void Debug(int u = 1) {
        for (int i = 1; i <= top; i++) putchar(sta[i] + 'a');
        putchar('\n');
        for (int i = 0; i < 26; i++)
            if (to[u][i]) {
                sta[++top] = i;
                Debug(to[u][i]);
                --top;
            }
    }
};
SAM S;  // global, so its large arrays live in static, zero-initialised storage

// Reads strings A and B, builds the automaton on A, and prints the number of
// (end position in A, end position in B, length) triples whose substrings are
// equal, as accumulated by SAM::Calc.
int main(void) {
#ifdef LOCAL
    // Local testing only. Left unconditional, these calls make the program
    // ignore the judge's stdin/stdout; if 1.in is missing, stdin is lost.
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
#endif
    read(a);
    read(b);
    S.Insert(a, a + strlen(a));
    S.Sort();
    S.dp();
    cout << S.Calc(b, b + strlen(b)) << '\n';
    return 0;
}