简略题意:
B串的所有子串在A串中的出现次数之和是多少。
首先考虑如何求A串自身的所有子串在A中的出现次数之和。
考虑 $parent$ 树,设树上节点数为 $cnt$,那么答案就是:
$$\sum_{i=1}^{cnt}\left(maxlen[i]-minlen[i]+1\right)\times right[i]$$
其中 $minlen[i]=maxlen[fa[i]]+1$,因此每个节点的贡献为 $(maxlen[i]-maxlen[fa[i]])\times right[i]$;
$maxlen[i]$ 在建树的时候得到;
$right[i]$ 可以通过拓扑排序后求得:
// Excerpt: computing right[] once the automaton has been built.
// Step 1: count, for every state, how many transitions point into it.
for(int i = 1; i <= cnt; i++)
for(int j = 0; j < 26; j++) {
if(trans[i][j]) deg[trans[i][j]]++;
}
// Step 2: breadth-first topological sort of the transition graph starting
// from state 1; G records the states in the order they are dequeued.
queue<int> q;
q.push(1);
while(!q.empty()) {
int u = q.front(); q.pop();
G.push_back(u);
for(int i = 0; i < 26; i++)
if(trans[u][i]) {
--deg[trans[u][i]];
if(deg[trans[u][i]] == 0)
q.push(trans[u][i]);
}
}
// Step 3: feed str through the automaton; the state reached after each
// character gets one occurrence.
int p = 1;
for(int i = 0; i < len; i++) {
p = trans[p][str[i]-'a'];
right[p]++;
}
// Step 4: walk G backwards and add every state's count to its fa.
for(int i = G.size()-1; i >= 0; i--) {
int val = G[i];
right[fa[val]] += right[val];
}
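现在需要求B在A上的该答案:让B在A的自动机上走,每到一个节点就累加答案即可。需要注意动态维护B当前的匹配长度(失配时沿 $fa$ 往上跳,并把匹配长度更新为所到节点的 $maxlen+1$);当前节点只按实际匹配长度计入贡献,它的所有祖先则按完整的长度区间计入。另外,从每一个节点暴力往上跳来累加祖先的贡献会超时,所以可以把“某节点到根这条链上的贡献之和”记忆化一下。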
#define others
#ifdef poj
#include <iostream>
#include <cstring>
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <queue>
#include <string>
#include <map>
#include <set>
#endif // poj
#ifdef others
#include <bits/stdc++.h>
#endif // others
//#define file
// all(x): expands to the iterator pair "x.begin(), x.end()" so a container can
// be handed to an algorithm that takes a range. x appears twice in the
// expansion and is not parenthesised.
#define all(x) x.begin(), x.end()
// File-wide using-directive: the code below names standard-library entities
// without the std:: prefix.
using namespace std;
// Tolerance used by dcmp() when comparing a double against zero.
#define eps 1e-8
// The constant pi, obtained as acos(-1.0).
const double pi = acos(-1.0);
// Short aliases for the 64-bit integer types.
typedef long long LL;
typedef unsigned long long ULL;
/// Raises `a` to `b` when `b` is the larger value; otherwise `a` is untouched.
void umax(int &a, int b) {
    if (a < b) {
        a = b;
    }
}
/// Lowers `a` to `b` when `b` is the smaller value; otherwise `a` is untouched.
void umin(int &a, int b) {
    if (b < a) {
        a = b;
    }
}
/// Sign of `x` with a tolerance: 0 when |x| <= eps, 1 when x is positive,
/// and -1 in every remaining case.
int dcmp(double x) {
    if (fabs(x) <= eps) {
        return 0;
    }
    if (x > 0) {
        return 1;
    }
    return -1;
}
/// Rebinds stdin to "data_in.txt" (read) and stdout to "data_out.txt" (write).
/// The results of both freopen calls are ignored.
void file() {
    static const char *const kInputPath = "data_in.txt";
    static const char *const kOutputPath = "data_out.txt";
    freopen(kInputPath, "r", stdin);
    freopen(kOutputPath, "w", stdout);
}
namespace solver {
/**
 * Suffix automaton of a text A, plus a query that sums, over every substring
 * occurrence in a second string B, how many times that substring occurs in A.
 *
 * State ids start at 1 (the root); id 0 doubles as "no transition" and as the
 * parent of the root. Characters are mapped with `ch - 'a'` and used as an
 * index without a range check, so input is presumably lowercase letters only
 * -- TODO confirm against the problem statement.
 *
 * Usage: init(); build(A); solve(B);
 */
namespace SAM{
static const int maxnode = 2e6+10;  // state pool size (original note: allocate at least twice the text length)
static const int maxn = 26;         // alphabet size
int p, q, np, nq;                   // scratch state ids used by add()
int cnt, ST_T, len;                 // cnt: states allocated; ST_T: state of the whole text so far;
                                    // len: length of the string last given to build()/solve()
// trans[s][c]: target of the c-transition of s (0 = none); l[s]: longest
// length represented by s; fa[s]: parent (suffix link) of s.
int trans[maxnode][maxn], l[maxnode], fa[maxnode];

/// Allocates the next state id, clears its transitions, parent and length,
/// and returns the id.
int newnode() {
    const int x = ++cnt;
    memset(trans[x], 0, sizeof trans[x]);
    fa[x] = 0;
    l[x] = 0;
    return x;
}

/// Resets the automaton to a single root state (id 1).
void init() {
    cnt = 0;
    ST_T = newnode();
}

/// Appends character index c (0..maxn-1) to the text T represented so far.
void add(int c) {
    // p = ST(T); create np = ST(Tc) with l[np] = l[p] + 1.
    p = ST_T;
    np = ST_T = newnode();
    l[np] = l[p] + 1;
    // Every ancestor of p (p included) lacking a c-transition now goes to np.
    while (p && !trans[p][c]) {
        trans[p][c] = np;
        p = fa[p];
    }
    if (!p) {
        // No ancestor has a c-transition: np hangs directly under the root.
        fa[np] = 1;
        return;
    }
    // p is the first ancestor that has a c-transition; q is its target.
    q = trans[p][c];
    if (l[p] + 1 == l[q]) {
        fa[np] = q;
        return;
    }
    // Otherwise split q: nq copies q's transitions and parent but has
    // length l[p] + 1; both q and np become children of nq.
    nq = newnode();
    l[nq] = l[p] + 1;
    memcpy(trans[nq], trans[q], sizeof(trans[q]));
    fa[nq] = fa[q];
    fa[np] = fa[q] = nq;
    // Redirect to nq every ancestor of p whose c-transition pointed at q.
    // (Row 0 of trans is never written, so stopping at p == 0 matches the
    // original loop, which relied on trans[0][c] being 0.)
    while (p && trans[p][c] == q) {
        trans[p][c] = nq;
        p = fa[p];
    }
}

int deg[maxnode], right[maxnode];  // deg: incoming-transition counts; right: occurrence counts
std::vector<int> G;                // states in the order the topological BFS dequeues them

/**
 * Feeds str into the automaton and computes right[] for every state.
 * Intended to be called once after init(). All derived data (deg, right, G)
 * is rebuilt from scratch on each call, so an init()/build() pair can be
 * repeated safely.
 */
void build(char *str) {
    len = strlen(str);
    memset(right, 0, sizeof right);
    memset(deg, 0, sizeof deg);
    // G must start empty: entries left from an earlier build() would be
    // replayed in the accumulation pass below and inflate right[].
    G.clear();

    for (int i = 0; i < len; i++) {
        add(str[i] - 'a');
    }

    // In-degree of every state in the transition graph.
    for (int s = 1; s <= cnt; s++) {
        for (int c = 0; c < maxn; c++) {
            if (trans[s][c]) {
                deg[trans[s][c]]++;
            }
        }
    }

    // Breadth-first topological order starting at the root.
    std::queue<int> pending;
    pending.push(1);
    while (!pending.empty()) {
        const int u = pending.front();
        pending.pop();
        G.push_back(u);
        for (int c = 0; c < maxn; c++) {
            const int v = trans[u][c];
            if (v && --deg[v] == 0) {
                pending.push(v);
            }
        }
    }

    // The state reached after each character of str gets one occurrence.
    int cur = 1;
    for (int i = 0; i < len; i++) {
        cur = trans[cur][str[i] - 'a'];
        right[cur]++;
    }

    // Reverse order: hand every state's count up to its parent.
    for (int i = static_cast<int>(G.size()) - 1; i >= 0; i--) {
        const int v = G[i];
        right[fa[v]] += right[v];
    }
}

long long dp[maxnode];  // memo for dfs(); -1 marks "not computed yet"

/**
 * Returns the sum of (l[v] - l[fa[v]]) * right[v] over x and all of its
 * ancestors, caching the value for every state visited. dfs(0) is 0.
 * Written iteratively so the call stack does not grow with the length of
 * the parent chain.
 */
long long dfs(int x) {
    static std::vector<int> chain;
    chain.clear();
    int v = x;
    // Climb until a state with a cached value, or the sentinel 0, is reached.
    while (v != 0 && dp[v] == -1) {
        chain.push_back(v);
        v = fa[v];
    }
    long long acc = (v == 0) ? 0 : dp[v];
    // Fill the cache top-down along the chain just climbed.
    for (int i = static_cast<int>(chain.size()) - 1; i >= 0; i--) {
        const int u = chain[i];
        acc += static_cast<long long>(l[u] - l[fa[u]]) * right[u];
        dp[u] = acc;
    }
    return acc;
}

/**
 * Runs str through the automaton built by build() and returns the total
 * number of occurrences in the built text of all substrings of str
 * (each substring position of str counted separately).
 */
long long solve(char *str) {
    memset(dp, -1, sizeof dp);
    long long res = 0;
    int cur = 1;      // current state
    int matched = 0;  // length of the longest suffix of str[0..i] present in the text
    len = strlen(str);
    for (int i = 0; i < len; i++) {
        const int c = str[i] - 'a';
        if (trans[cur][c]) {
            cur = trans[cur][c];
            matched++;
        } else {
            // Mismatch: shorten the match along parent links until the
            // character can be consumed.
            while (cur && !trans[cur][c]) {
                cur = fa[cur];
            }
            if (cur == 0) {
                matched = 0;
                cur = 1;
            } else {
                matched = l[cur] + 1;
                cur = trans[cur][c];
            }
        }
        // The current state contributes only the lengths actually matched;
        // every ancestor contributes its full length range.
        res += static_cast<long long>(matched - l[fa[cur]]) * right[cur];
        res += dfs(fa[cur]);
    }
    return res;
}
};

char s1[550000], s2[550000], s3[550000];  // s3 is not used in this file

/// Reads text A and query B from stdin and prints the answer (no trailing newline).
void solve() {
    // The field width keeps a token from overrunning the 550000-byte buffers.
    // If a read fails, the zero-initialised buffer is left as an empty string.
    scanf("%549999s", s1);
    scanf("%549999s", s2);
    SAM::init();
    SAM::build(s1);
    std::cout << SAM::solve(s2);
}
}
/// Program entry point: runs the single case handled by solver::solve().
int main() {
    // file();  // left disabled: would rebind stdin/stdout to the data files
    solver::solve();
}