题源链接:https://nanti.jisuanke.com/t/41389
借助这道题学会了回文自动机wwww,超开心~
打比赛时一直没法静下来好好读读回文自动机的板子,盲改一直gg 毕竟连ac自动机的基础都没有
前几天抽时间好好学了学,发现这玩意儿原理挺容易理解的,板子也是真的好写,学习过程并未遇到什么阻碍
先来说这道题,就是sum (回文串出现个数 * 回文串中不同字母个数)
统计本质不同回文串各自出现的个数,个数在建树的时候各个结点统计一下,然后逆序下放次数,即长回文串出现,其包含的短的必定出现:
for (int i = cnt; i >= 2; i--) {
ans[fail[i]] += ans[i];
}
再写个搜索算一下各回文串字母个数,我a题的时候写的搜索长什么样已经忘了,这里用一个在别处看到的状压dfs,性能更为优秀:
void dfs(int now, int state, int c) {
for (int i = 0; i < 26; ++i) {
if (ch[now][i]) {
int nc = c, nstate = state;
if ((state & (1 << i)) == 0) nstate |= 1 << i, nc++;
anss += nc * ans[ch[now][i]];
dfs(ch[now][i], nstate, nc);
}
}
}
(不得不说别人写的状压看起来真的优雅...
完整code:
#include <bits/stdc++.h>
#define numm ch-48
#define pd putchar(' ')
#define pn putchar('\n')
#define pb push_back
#define fi first
#define se second
#define fre1 freopen("1.txt","r",stdin)
#define fre2 freopen("2.txt","w",stdout)
typedef long long int ll;
typedef long long int LL;
using namespace std;
template<typename T>
void read(T& res) {
bool flag = false;
char ch;
while (!isdigit(ch = getchar())) (ch == '-') && (flag = true);
for (res = numm; isdigit(ch = getchar()); res = (res << 1) + (res << 3) + numm);
flag && (res = -res);
}
template<typename T>
void write(T x) {
if (x < 0) putchar('-'), x = -x;
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
}
static auto _ = []()
{
ios::sync_with_stdio(false);
cin.tie(0);
return 0;
}();
#pragma GCC optimize(3,"Ofast","inline")
//
const int maxn = 3 * 1e5 + 5;
int fail[maxn], len[maxn];
int ch[maxn][26] = { 0 };
int cnt, last, k;
ll ans[maxn] = { 0 };
ll anss = 0;
void dfs(int now, int state, int c) {
for (int i = 0; i < 26; ++i) {
if (ch[now][i]) {
int nc = c, nstate = state;
if ((state & (1 << i)) == 0) nstate |= 1 << i, nc++;
anss += nc * ans[ch[now][i]];
dfs(ch[now][i], nstate, nc);
}
}
}
int main() {
string s;
cin >> s;
s = '#' + s;
fail[0] = 1, fail[1] = 0;
len[0] = 0, len[1] = -1;
last = 0, cnt = 1;
for (int i = 1; i < s.length(); i++) {
while (s[i - len[last] - 1] != s[i]) {
last = fail[last];
}
if (!ch[last][s[i] - 'a']) {
len[++cnt] = len[last] + 2;
int j = fail[last];
while (s[i - len[j] - 1] != s[i]) {
j = fail[j];
}
fail[cnt] = ch[j][s[i] - 'a'];
ch[last][s[i] - 'a'] = cnt;
}
ans[ch[last][s[i] - 'a']]++;
last = ch[last][s[i] - 'a'];
}
for (int i = cnt; i >= 2; i--) {
ans[fail[i]] += ans[i];
}
dfs(0, 0, 0); dfs(1, 0, 0);
write(anss);
return 0;
}