题意:
给出一个字符串 S,定义一个字符串的价值为:出现字母的种类数 × 字符串出现的次数。求 S 所有回文子串的价值之和。
思路:
本题除了求每个本质不同子串的出现次数外,还需要求每个本质不同子串出现的字母种类数。
可以在 PAM 的每个节点额外维护一个长度为 26 的数组,表示这个点代表的回文串用到了哪些字母,在新建节点的时候顺便维护即可。具体看代码。
代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 3e5 + 10;
int lstp, tot = 1, fa[N], ch[N][26], len[N];
long long cnt[N];
int type[N][26];
char s[N];
int getfa(int p, int idx)
{
while(idx - len[p] - 1 < 0 || s[idx - len[p] - 1] != s[idx]){
p = fa[p];
}
return p;
}
void insert(int c, int idx)
{
if(!idx){
fa[0] = 1, fa[1] = len[0] = 0, len[1] = -1;
}
int p = getfa(lstp, idx); //上一个节点
if(!ch[p][c]){
int np = ++tot;
int v = getfa(fa[p], idx);
fa[np] = ch[v][c];
len[np] = len[p] + 2;
ch[p][c] = np;
}
lstp = ch[p][c]; //走到最新的节点
memcpy(type[lstp], type[p], sizeof type[p]); //先将上一个节点代表子串包含的字符复制
type[lstp][c] = 1; //纳入当前字符
cnt[lstp]++;
}
signed main()
{
scanf("%s", s);
for(int i = 0; s[i]; ++i){
insert(s[i] - 'a', i);
}
long long sum = 0;
for(int i = tot; i > 0; --i){
cnt[fa[i]] += cnt[i]; //拓扑更新每个节点代表子串的出现次数
int tmp = 0;
for(int j = 0; j < 26; ++j){
if(type[i][j]) ++tmp;
}
sum += tmp * cnt[i];
}
printf("%lld\n", sum);
return 0;
}