解决字符串中回文串的相关问题
/*
---------------回文自动机PAM---------------
- 传入字符串下标从0开始
- 本质不同的回文子串个数
- 所有回文子串个数
- 每种回文串出现的次数 cnt(需要get_cnt)
- 每种回文串的长度 len
- 以下标 i 为结尾的回文串的个数 sed
- 每个回文串在原串中出现的起始位置 record
*/
struct PAM {
string s;
int n;
int nxt[N][26];
int fail[N]; // 当前节点最长回文后缀的节点
int len[N]; // 当前节点表示的回文串的长度
int cnt[N]; // 当前节点回文串的个数, 在getcnt后可得到全部
int sed[N]; // 以当前节点为后缀的回文串的个数
int record[N]; // 每个回文串在原串中出现的位置
int tot; // 节点个数
int last; // 上一个节点
void init()
{
tot = 0;
memset(fail, 0, sizeof fail);
memset(cnt, 0, sizeof cnt);
memset(sed, 0, sizeof sed);
memset(len, 0, sizeof len);
memset(nxt, 0, sizeof nxt);
memset(record, 0, sizeof record);
}
int newnode(int lenx)
{
for (int i = 0; i < 26; i++)
nxt[tot][i] = 0;
sed[tot] = cnt[tot] = 0;
len[tot] = lenx;
return tot;
}
void build(string ss)
{
tot = 0;
newnode(0);
tot = 1, last = 0;
newnode(-1);
fail[0] = 1;
n = ss.size();
s = " " + ss;
}
int getfail(int x, int n)
{
while (n - len[x] - 1 <= 0 || s[n - len[x] - 1] != s[n])
x = fail[x];
return x;
}
void insert(char cc, int pos)
{
int c = cc - 'a';
int p = getfail(last, pos);
if (!nxt[p][c])
{
tot++;
newnode(len[p] + 2);
fail[tot] = nxt[getfail(fail[p], pos)][c];
len[tot] = len[p] + 2;
sed[tot] = sed[fail[tot]] + 1;
nxt[p][c] = tot;
}
last = nxt[p][c];
cnt[last]++;
record[last] = pos;
}
void insert()
{
for (int i = 1; i <= n; i++) insert(s[i], i);
}
void get_cnt()
{
for (int i = tot; i > 0; i--)
cnt[fail[i]] += cnt[i];
}
int get_diff_cnt() // 本质不同的回文子串个数
{
return tot - 1;
}
int get_all_cnt() // 所有回文子串个数(本质相同的多次计算)
{
int sum = 0;
get_cnt();
for (int i = 2; i <= tot; i ++ ) sum += cnt[i];
return sum;
}
} pam;
//------------------------------------