题意:求字符串所有回文子串相交对数。
题解:回文自动机
要求相交对数,那么先求出对数的总数,然后减去不相交的对数就好了。
s
u
m
[
]
sum[]
sum[]:从
i
i
i开始往后的回文子串个数。从后往前更新。
a
d
d
(
)
add()
add()返回值:以
i
i
i结尾的回文子串个数。
总对数
−
Σ
(
s
u
m
(
)
∗
a
d
d
(
)
)
- Σ(sum()*add())
−Σ(sum()∗add()) 就是答案了。
然后就是
n
e
x
t
next
next数组会超空间,用邻接表。
这题也可以用
M
a
n
a
c
h
e
r
Manacher
Manacher做,用
M
p
[
]
Mp[]
Mp[]更新就可以了。
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<fstream>
#include<set>
#include<map>
#define ll long long
using namespace std;
const int MAXN = 2e6 + 5;
const int mod = 51123987;
int n;
char str[MAXN];
int sum[MAXN];
struct link { //next的邻接表
int u[MAXN]; int v[MAXN];
int next[MAXN]; int head[MAXN];
int tot;
void clear() {
memset(head, -1, sizeof(head));
tot = 0;
}
void clear(int x) { head[x] = -1; }
int get(int x, int y) {
for (int i = head[x]; i != -1; i = next[i]) {
if (u[i] == y)
return v[i];
}
return 0;
}
void insert(int x, int y, int z) {
u[tot] = y; v[tot] = z;
next[tot] = head[x];
head[x] = tot++;
}
};
struct Tree {
//int next[MAX+5][26];
int fail[MAXN];//指向当前节点回文串中最长的后缀回文子串
int cnt[MAXN];//当前节点的回文串一共有多少个
int num[MAXN];//当前节点为结尾的回文子串的个数
int len[MAXN];//当前节点的回文串的长度
int s[MAXN];//字符串
int last;//回文树最后一个节点
int n;
int p;//回文树当前有多少个节点
link next;//邻接表,表示当前节点回文串在两端添加一个字符形成的另一个节点的回文串
int new_node(int x) {
//memset(next[p],0,sizeof(next[p]));
cnt[p] = 0;
next.clear(p);
num[p] = 0;
len[p] = x;
return p++;
}
void init() {
next.clear();
p = 0;
new_node(0);
new_node(-1);
last = 0;
n = 0;
s[0] = -1;
fail[0] = 1;
}
int get_fail(int x) {
while (s[n - len[x] - 1] != s[n])
x = fail[x];
return x;
}
int add(int x) { //返回以当前下标为末尾的回文串数量
x -= 'a';
s[++n] = x;
int cur = get_fail(last);
if (!(last = next.get(cur, x))) {
int now = new_node(len[cur] + 2);
fail[now] = next.get(get_fail(fail[cur]), x);
next.insert(cur, x, now);
num[now] = num[fail[now]] + 1;
last = now;
}
cnt[last]++;
return num[last];
}
ll Allsum() { //求出所有回文子串的数目
ll ret = 0;
for (int i = p - 1; i > 0; i--) {
cnt[fail[i]] = (cnt[fail[i]] + cnt[i]) % mod;
ret = (ret + cnt[i]) % mod;
}
return ret;
}
}tree;
int main() {
scanf("%d", &n);
scanf("%s", str);
tree.init();
sum[n] = 0;
for (int i = n - 1; i >= 0; i--) sum[i] = (sum[i + 1] + tree.add(str[i])) % mod;
tree.init();
ll ans = 0, res = 0;
for (int i = 0; i <= n - 1; i++) {
ans += (1ll * tree.add(str[i]) * sum[i + 1]) % mod;
ans %= mod;
}
res = tree.Allsum();
ans = ((res * (res - 1) / 2 % mod - ans) % mod + mod) % mod;
printf("%lld\n", ans);
return 0;
}