回文树(回文自动机),他的功能如下:
- 求前缀字符串中的本质不同的回文串种类
- 求每个本质不同回文串的个数
- 以下标为结尾的回文串个数/种类
- 每个本质不同回文串包含的本质不同回文串种类
- next[][]:类似于字典树,指向当前字符串在两段同时加上一个字符
- fail[]:fail指针,类似于AC自动机,返回失配后与当前结尾的最长回文串本质上不同的最长回文后缀
- cnt[]:在最后统计后它可以表示形如以为结尾的回文串中最长的那个串个数
- num[]:表示以结尾的回文串的种类数
- len[]:表示以为结尾的最长回文串长度
- s[]:存放添加的字符
- last:表示上一个添加的字符的位置
- n:表示字符数组的第几位
- p:表示树中节点的指针
构造回文树需要的空间复杂度为O(N * 字符集大小),时间复杂度为O(N * log(字符集大小) )
模板:
#include<bits/stdc++.h>
typedef long long LL;
const int MAXN = 3*100005 ;
const int N = 26 ;
char s[MAXN];
struct Palindromic_Tree {
int nxt[MAXN][N] ;//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
int fail[MAXN] ;//fail指针,失配后跳转到fail指针指向的节点
int cnt[MAXN] ;
int num[MAXN] ;
int sz[MAXN] ;//记录回文串结尾i时,能够构成的回文串的数量
int v[MAXN] ;//标记数组
int t[MAXN] ;//记录贡献值:2 1
int len[MAXN] ;//len[i]表示节点i表示的回文串的长度
int S[MAXN] ;//存放添加的字符
int last ;//指向上一个字符所在的节点,方便下一次add
int n ;//字符数组指针
int p ;//节点指针
int newnode ( int l ) {//新建节点
for ( int i = 0 ; i < N ; ++ i ) nxt[p][i] = 0 ;
cnt[p] = 0 ;
num[p] = 0 ;
len[p] = l ;
return p ++ ;
}
void init () {//初始化
p = 0 ;
newnode ( 0 ) ;
newnode ( -1 ) ;
last = 0 ;
n = 0 ;
S[n] = -1 ;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1 ;
}
int get_fail ( int x ) {//和KMP一样,失配后找一个尽量最长的
while ( S[n - len[x] - 1] != S[n] ){
x = fail[x] ;
}
return x ;
}
void add(int c){
c -= 'a';
S[++n] = c;
int cur = get_fail(last);
if(!nxt[cur][c]){
int now = newnode(len[cur] + 2);
fail[now] = nxt[get_fail(fail[cur])][c];
nxt[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = nxt[cur][c];
cnt[last]++;
}
void dfs(int x){
t[x] = (v[x] == 0) + (v[fail[x]] == 0);
v[x]++; v[fail[x]]++;
sz[x] = 1;
for(int i = 0; i < 26; i++){
if(!nxt[x][i]) continue;
dfs(nxt[x][i]);
sz[x] += sz[nxt[x][i]];
}
v[x]--;
v[fail[x]]--;
}
/**求回文串有多少个本质不同的回文串是另一个回文串的子串,共有多少对**/
LL count1 () {
LL ans = 0;
dfs(0); dfs(1);
for(int i = 2; i < p; i++){
ans += sz[i] * t[i] - 1;
}
return ans;
}
/**统计本质相同的回文串出现的次数**/
void count2(){
for(int i = p-1; i >= 0; i--)
cnt[fail[i]] += cnt[i];
//父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
//cnt表示长度为len的回文串的出现次数
}
}PT ;
int main()
{
int T;
scanf("%d", &T);
for(int cas = 1; cas <= T; cas++){
PT.init();
scanf("%s", s);
int len = strlen(s);
for(int i = 0; i < len; i++){
PT.add(s[i]);
}
LL ans = PT.count1();
printf("Case #%d: %lld\n", cas, ans);
}
return 0;
}