题意:给你一个字符串,将其中所有本质不同的回文子串求出来放在一个集合。问集合中一个串是另一个串的子串的对数是多少?
题解:网上看了挺多题解的,但觉得有些没写清楚,所以写了一篇自己的理解…
首先用回文树将所有本质不同的回文串求出来。然后加上fail数组的边,回文树就变成一颗树,且每一个节点都有一条回边的‘图’。问题就变成了找图中有多少对可达点?如果放在一般图中这个问题是不好求的,但由于回文树有比较优秀的性质。所以我们可以O(N) 的时间复杂度求解。
具体过程如下:首先不考虑fail边,把回文树单单看成一棵树,然后沿着树边进行dfs,将当前节点的祖先都染色标记。即染父亲节点,再染fail指针指向的节点,如果已经染色则不用再染(父亲节点一定未被染色),途中记录一下当前已经染色的个数。则这些点就是可达当前节点的节点,累计入答案即可。
为什么这样做是对的?首先沿着树边(ch数组)走,则表示当前回文串左右各添加一个字符的所有父亲串转移过来。而标记fail数组表示以当前回文串的右端点结尾的最长回文子串转移过来?为什么只转移一次就够了。按道理我们应该沿着fail指向的节点继续往前遍历完可以转移到该节点的祖先。答案就是沿着树边走,我们可以保证fail指向节点的所有祖先已经遍历过了。
简单证明:
如图表示又下面的字符串左右添加一个字符变成上面的字符串(沿着树边走)
设上面的串为A,下面的串为B。len[]表示最长回文后缀。则len[A]和len[B]的关系只有两种,len[A]=len[B]+2,或者len[A]=len[B]=1。因为上面的串包含下面的串,上面最长回文后缀一定包含下面的最长回文后缀,且长度差为左右两个相同字符。若不包含,则说明各自的回文后缀为1。这个条件说明了沿着树边走,当前串的fail指向的节点是其父亲fail指向节点的儿子或者两者为1。这样我们就不用沿着fail一直往上更新了。所以时间复杂度也为O(N)。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 300005 ;
const int SIZE = 26 ;
ll ret;
struct Palindromic_Tree {
int ch[maxn][SIZE] ;//ch指针,ch指针和字典树类似,指向的串为当前串两端加上同一个字符构成
//如果空间要求太高,就用邻接表实现
int fail[maxn] ;//fail指针,失配后跳转到fail指针指向的节点
int cnt[maxn] ; //表示i节点在s中出现的次数(建树时求出的不是完全的,最后count()函数跑一遍以后才是正确的)
int num[maxn] ; //fail指针的深度
int len[maxn] ;//len[i]表示节点i表示的回文串的长度(一个节点表示一个回文串)
int vis[maxn] ;
int S[maxn] ;//存放添加的字符
int last ;//指向新添加一个字母后所形成的最长回文串表示的节点。
int n ;//表示添加的字符个数。
int tot ;//表示添加的节点个数。
int newnode ( int l ) {//新建节点,长度为l
for ( int i = 0 ; i < SIZE ; ++ i ) ch[tot][i] = 0 ;
cnt[tot] = 0 ;
num[tot] = 0 ;
len[tot] = l ;
return tot ++ ;
}
void init () {//初始化
tot = 0 ;
newnode ( 0 ) ; //偶数回文串的根
newnode ( -1 ) ; //奇数回文串的根
last = 0 ;
n = 0 ;
S[n] = -1 ;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1 ;
}
int getfail ( int x ) {//和KMP一样,失配后找一个尽量最长的
while ( S[n - len[x] - 1] != S[n] ) x = fail[x] ;
return x ;
}
void add ( int c ) {
c -= 'a' ; //具体问题具体分析,有可能是数字串
S[++ n] = c ;
int cur = getfail (last) ;//通过上一个回文串找这个回文串的匹配位置
if ( !ch[cur][c] ) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
int now = newnode ( len[cur] + 2 ) ;//新建节点
fail[now] = ch[getfail ( fail[cur] )][c] ;//和AC自动机一样建立fail指针,以便失配后跳转
ch[cur][c] = now ;
num[now] = num[fail[now]] + 1 ;
}
last = ch[cur][c] ;
cnt[last] ++ ;
}
void dfs(int u,ll sum){
vis[u]=1;
ret+=sum;
for(int i=0;i<26;i++){
int v = ch[u][i];
if(!v)continue;
if(vis[fail[v]]){
dfs(v,sum+(u!=0&&u!=1));
}else{
int d = fail[v]!=0&&fail[v]!=1;
vis[fail[v]]=1;
dfs(v,sum+(u!=0&&u!=1)+d);
vis[fail[v]]=0;
}
}
vis[u]=0;
}
} T;
char s[maxn];
int main() {
int tt;
scanf("%d",&tt);
for(int kase=1;kase<=tt;kase++) {
T.init();
scanf("%s",s);
int len=strlen(s);
for(int i=0; i<len; i++) {
T.add(s[i]);
}
ret=0;
T.dfs(1,0);T.dfs(0,0);
printf("Case #%d: %lld\n",kase,ret);
}
return 0;
}