题意:
找出字符串中所有三元组(i,j,k)使得s[i,j]和s[j+1,k]均为回文串。问所有三元组的i*k的和是多少。
分析:
令中间的j固定,然后计算两边所有的i和k。可以计算出 ∑i ∑ i , ∑k ∑ k ,然后这两个和相乘,就是j固定时所有的i*k的和。
对于第j个字符,可以在回文树上计算出,所有以字符j结尾的回文串个数num[j]和这些回文串的长度总和sum[j],有了这两个量就可以计算出 ∑i ∑ i 。
∑k ∑ k 只需要反向遍历字符串建回文树即可。这儿要注意一下需要把反向字符串建立出来然后作为参数传进去,不能只在insert函数中写对c,否则while(s[n-1-len[x]] != s[n]) x = fail[x];这个地方会出错。
代码:
#include <bits/stdc++.h>
using namespace std;
#define ms(a,b) memset(a,b,sizeof(a))
#define lson rt*2,l,(l+r)/2
#define rson rt*2+1,(l+r)/2+1,r
typedef unsigned long long ull;
typedef long long ll;
const int MAXN=1e6+5;
const double EPS=1e-8;
const int INF=0x3f3f3f3f;
const int MOD = 1e9+7;
struct palin_tree{
int ch[MAXN][26], fail[MAXN], len[MAXN], last, tot, num[MAXN], sum[MAXN];
void init(){
ms(num,0); ms(sum,0); ms(ch,0); ms(fail,0); ms(len,0);
last = 0; len[1] = -1; tot = 1; fail[0] = 1;
}
int insert(int c, int n, char *s){
int x = last;
while(s[n-1-len[x]] != s[n]) x = fail[x];
if(!ch[x][c]){
int v = ++tot, k = fail[x];
len[v] = len[x] + 2;
while(s[n-1-len[k]] != s[n]) k = fail[k];
fail[v] = ch[k][c];
ch[x][c] = v;
num[v] = num[fail[v]] + 1;
sum[v] = ((ll)sum[fail[v]] + len[v])%MOD;
}
last = ch[x][c];
return last;
}
}t;
char s1[MAXN], s2[MAXN];
int tmp[MAXN];
int main(){
ios::sync_with_stdio(false);
while(~scanf("%s",s1+1)){
s1[0] = s2[0] = -1;
ms(tmp,0);
int n = strlen(s1+1);
for(int i=1;i<=n;i++){
s2[i] = s1[n-i+1];
}
t.init();
for(int i=n;i>=1;i--){
int x = t.insert(s1[i]-'a',n-i+1,s2);
tmp[i] = (t.sum[x]+(ll)(i-1)*t.num[x])%MOD;
}
ll ans = 0;
t.init();
for(int i=1;i<n;i++){
int x = t.insert(s1[i],i,s1);
ans=(ans+(((ll)t.num[x]*(i+1)-t.sum[x])%MOD)*tmp[i+1]%MOD)%MOD;
}
cout << ans << endl;
}
return 0;
}