题意是让你在一个字符串里找三元组,(i,j,k),[i...j],[j+1...k]都是回文,求∑∑i×k
这题主要要看成,枚举中间的j,然后找到j两边的所有回文串
两种解法:第一种回文树,假设左右的回文串长度为a,b
∑∑i×k=∑∑∑(j−1−a+1)×(j+b−1)
=∑presum∗sufsum∗j2
+j∗(sufsum∗prenum−presum∗sufnum−prenum∗sufnum)
−presum∗(sufsum−sufnum)
prenum表示j−1为最右端点的回文串的个数,presum为这些回文串的长度
suf也是,所以只要正着回文树一遍,倒着再来一遍
卡内存,第二遍要直接求和
第二种解法,mancher
∑∑i×k=∑jsum[0][j−1]∗sum[1][j]
首先mancher预处理,直接保留加了符号的串即可,因为偶数位是原来的串
sum[2][i],0表示i为右端点的所有回文串的左端点之和,1表示i为左端点的所有回文串的右端点之和
对于i位置,mancher求出来的p[i]表示包含i之后左右可以延伸多少距离
所以l=i−p[i]+1,r=i+p[i]−1
对于l为左端点的话,sum[1][l]+=r,sum[1][l+1]+=(r−1),这样一个首项为r,公差为−1的等差数列,一直到i为止
这个等差数列怎么求呢
可以给sum[1][l]+=r,sum[1][i+1]+=−i
delta[1][l+1]+=−1,delta[1][i+1]+=1
然后每次求的时候delta[1][i]+=delta[1][i−1],sum[1][i]+=sum[1][i−1]+delta[1][i]
sum[0]一样做,这个等差数列求的方法很好,可以O(n)的算出区间加等差数列,如果没有其他修改操作的话
代码(回文树):
#include <map>
#include <set>
#include <ctime>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#pragma comment(linker,"/STACK:102400000,102400000")
using namespace std;
#define MAX 1000005
#define MAXN 1000005
#define maxnode 205
#define sigma_size 26
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lrt rt<<1
#define rrt rt<<1|1
#define middle int m=(r+l)>>1
#define LL long long
#define ull unsigned long long
#define mem(x,v) memset(x,v,sizeof(x))
#define lowbit(x) (x&-x)
#define pii pair<int,int>
#define bits(a) __builtin_popcount(a)
#define mk make_pair
#define limit 10000
//const int prime = 999983;
const int INF = 0x3f3f3f3f;
const LL INFF = 0x3f3f;
const double pi = acos(-1.0);
const double inf = 1e18;
const double eps = 1e-4;
const LL mod = 1e9+7;
const ull mx = 133333331;
/*****************************************************/
inline void RI(int &x) {
char c;
while((c=getchar())<'0' || c>'9');
x=c-'0';
while((c=getchar())>='0' && c<='9') x=(x<<3)+(x<<1)+c-'0';
}
/*****************************************************/
struct Palindromic_Tree {
int next[MAX][sigma_size] ;//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
int fail[MAX] ;//fail指针,指向这个节点表示的回文串的最长后缀回文串的开头。
//int cnt[MAX] ;//这个节点上本质相同的回文串的个数(最后count到cnt[1],表示所有回文串的个数)
int num[MAX] ; //节点i表示的回文串的后缀回文串的个数
int sum[MAX];
int len[MAX] ;//len[i]表示节点i表示的回文串的长度
char S[MAX] ;//存放添加的字符
int last ;//指向上一个字符所在的节点,方便下一次add
int n ;//字符数组指针
int p ;//节点指针
int newnode ( int l ) {//新建节点
for ( int i = 0 ; i < sigma_size ; ++ i ) next[p][i] = 0 ;
//cnt[p] = 0 ;
num[p] = 0 ;
sum[p] = 0;
len[p] = l ;
return p ++ ;
}
void init () {//初始化
p = 0 ;
newnode ( 0 ) ;
newnode ( -1 ) ;
last = 0 ;
n = 0 ;
S[n] = -1 ;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1 ;
}
int get_fail ( int x ) {//和KMP一样,失配后找一个尽量最长的
while ( S[n - len[x] - 1] != S[n] ) x = fail[x] ;
return x ;
}
pii add ( char b ) {
//c -= 'a' ;
S[++ n] = b ;
int c=b-'a';
int cur = get_fail ( last ) ;//通过上一个回文串找这个回文串的匹配位置
if ( !next[cur][c] ) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
int now = newnode ( len[cur] + 2 ) ;//新建节点
fail[now] = next[get_fail ( fail[cur] )][c] ;//和AC自动机一样建立fail指针,以便失配后跳转
next[cur][c] = now ;
num[now] = num[fail[now]] + 1 ;
sum[now]=sum[fail[now]]+len[now];
if(sum[now]>=mod) sum[now]-=mod;
}
last = next[cur][c] ;
//cnt[last] ++ ;
return mk(num[last],sum[last]);
}
}pt;
char s[MAX];
int prenum[MAX];
int presum[MAX];
int main(){
while(~scanf("%s",s+1)){
pt.init();
int len=strlen(s+1);
for(int i=1;i<=len;i++){
pii p=pt.add(s[i]);
prenum[i]=p.first;
presum[i]=p.second;
}
pt.init();
LL ans=0;
for(int i=len;i>1;i--){
pii p=pt.add(s[i]);
LL tmp=(((LL)prenum[i-1]*i%mod)*((LL)i*p.first%mod)%mod+i*((LL)p.second*prenum[i-1]%mod-(LL)p.first*presum[i-1]%mod-(LL)p.first*prenum[i-1]%mod)%mod-(LL)presum[i-1]*(p.second-p.first)%mod)%mod;
ans+=(tmp+mod)%mod;
if(ans>=mod) ans-=mod;
}
cout<<ans<<endl;
}
return 0;
}
代码(mancher):
#include <map>
#include <set>
#include <ctime>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#pragma comment(linker,"/STACK:102400000,102400000")
using namespace std;
#define MAX 1000005
#define MAXN 1000005
#define maxnode 205
#define sigma_size 26
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lrt rt<<1
#define rrt rt<<1|1
#define middle int m=(r+l)>>1
#define LL long long
#define ull unsigned long long
#define mem(x,v) memset(x,v,sizeof(x))
#define lowbit(x) (x&-x)
#define pii pair<int,int>
#define bits(a) __builtin_popcount(a)
#define mk make_pair
#define limit 10000
//const int prime = 999983;
const int INF = 0x3f3f3f3f;
const LL INFF = 0x3f3f;
const double pi = acos(-1.0);
const double inf = 1e18;
const double eps = 1e-4;
const LL mod = 1e9+7;
const ull mx = 133333331;
/*****************************************************/
inline void RI(int &x) {
char c;
while((c=getchar())<'0' || c>'9');
x=c-'0';
while((c=getchar())>='0' && c<='9') x=(x<<3)+(x<<1)+c-'0';
}
/*****************************************************/
char a[MAX];
const LL inv=500000004;
struct Manacher {
// 原串 a[i]: w a a b w s w f d
// 新串 s[i]: # w # a # a # b # w # s # w # f # d #
// 辅助数组 p[i]: 1 2 1 2 3 2 1 2 1 2 1 4 1 2 1 2 1 2 1
// p[i] := 新串以 s[i] 为中心向右延伸的回文距离 + 1 (自己)
// p[i]-1 := 原串以 s[i] 为中心的回文长度
static const int M = MAX << 1;
char s[M];
int n, p[M];
LL delta[2][M], sum[2][M];
void init(char* a) {
s[0] = '@'; s[1] = '#'; n = 2;
for(int i = 1; a[i]; ++i)
s[n++] = a[i], s[n++] = '#';
s[n] = 0;
}
int gao() {
int mx = 0, id, ret = 0;
for(int i = 1; i < n; ++i) {
p[i] = mx > i ? min(mx - i, p[2 * id - i]) : 1;
while(s[i - p[i]] == s[i + p[i]]) ++p[i];
if(mx < i + p[i]) mx = i + p[i], id = i;
ret = max(ret, p[i] - 1);
}
return ret;
}
void add(LL &x,LL y){
if(y<0) y+=mod;
x+=y;
if(x>=mod) x-=mod;
}
LL solve(){
mem(delta,0);
mem(sum,0);
for(int i=1;i<n;i++){
int l=i-p[i]+1;
int r=i+p[i]-1;
add(sum[0][l],r);
add(delta[0][l+1],-1);
add(sum[0][i+1],-i);
add(delta[0][i+1],1);
add(sum[1][i],i);
add(delta[1][i+1],-1);
add(sum[1][r+1],-l);
add(delta[1][r+1],1);
}
for(int i=0;i<2;i++){
for(int j=1;j<n;j++){
add(delta[i][j],delta[i][j-1]);
add(sum[i][j],sum[i][j-1]);
add(sum[i][j],delta[i][j]);
}
}
LL ans=0;
for(int i=2;i<n-2;i+=2){
ans+=((LL)sum[1][i]*inv%mod)*((LL)sum[0][i+2]*inv%mod)%mod;
if(ans>=mod) ans-=mod;
}
return ans;
}
} man;
int main(){
//freopen("1002.in","r",stdin);
//freopen("froggy.out","w",stdout);
while(~scanf("%s",a+1)){
man.init(a);
man.gao();
cout<<man.solve()<<endl;
}
return 0;
}