题目
题解
核心思路
由于总的回文串数量为级别,直接求相交的回文串不好求,考虑正难则反,
设S中的回文串总数为sum,则相交的回文串对数=总对数()-不相交的回文串对数
此题有特别多的做法,大致可分为两大类:
Manacher
设以i为右端点的回文串数量为,以i为左端点的回文串数量为,则不相交的回文串对数为,
用Manacher求出以每个下标为中心的最长回文串的长度,在这个范围内对左边的 g 和右边的 f 做贡献。
随便用一个常数小的数据结构(如BIT)或搞一个差分就完了。(~)
PAM(回文自动机)
还是一样的思路,不过PAM可以更直接地求出以每个下标为右端点的回文串数量(等于沿着fail链跑到根的深度),然后再反着求一遍;
不过用PAM就有一个比较棘手的问题,就是普通的PAM每个点都至少存有26大小的int数组,在这题n≤2e6下会MLE,
有多种解决办法:
#1.邻接表建PAM
常规正解做法,但是这会导致每个点的儿子储存无序,相当于用时间换空间,遍历起来我感觉会超时(不知道为什么,明明可以卡掉,但数据并没有让很多人超时),还可以用一些特殊方法(如手写个平衡树,控制空间常数)来优化,但是不好打;
#2.unordered_map存边
时间常数非常小但空间消耗特别大,如果每个点开一个unordered_map,第一个数据就会MLE。
这里有一个硬知识:多维数组的最后一维空间“对齐”,空间小,遍历常数小。(这个规律其实学过矩阵加速的人都用过,就是矩阵乘法里的改变枚举顺序的优化)
所以通过测试我们发现,开n个unordered_map,与开26个unordered_map、每个里存所有点的一条边,前者空间是后者的好几倍,而后者刚好可以卡过此题。
由于有很多比赛并不支持C++11(如某CCF的...),所以习惯上我不用这种做法。
#3.map存边
用普通map过这题是最难的一种方式,但我还是搞出来了。
由于map时间空间常数较大,且时间是一个log,每个点存一个map会MLE,开26个map虽然空间远小于题目限制(注意map和unordered_map的区别,前者是“远小于”,后者是“卡过”),但会TLE。(可能卡卡常可以过,但我常数就是不行,仰慕某知名博主zxy)
考虑用分块,开个map,每个map存个点的边,空间多了一万KB但没问题,时间刚好可过。
注意
这题的总回文串个数会爆longlong,也就是说无法直接算,要用 2 mod 51123987 的逆元,用欧拉函数算出来是25561994。
代码
PAM+map存边
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<vector>
#include<map>
#define ll long long
#define MAXN 2000005
#define MOD 51123987ll
#define N 1000
using namespace std;
inline ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+s-'0',s=getchar();
return f?x:-x;
}
int n,p[MAXN];
struct itn{
int num,len,fail;
void CL(){num=len=fail=0;}
}td[MAXN];
pair<int,char>pp;
map<pair<int,char>,int>ch[2005];
int IN=1,las=0;
ll SUM,ans,sum;
char s[MAXN];
inline int getf(int x,int n){
while(s[n-td[x].len-1]!=s[n])x=td[x].fail;
return x;
}
inline int findch(int cu,char c){
pp.first=cu,pp.second=c;
map<pair<int,char>,int>::iterator it=ch[cu/N].find(pp);
if(it==ch[cu/N].end())return 0;
else return it->second;
}
inline int extend(int n){
int cur=getf(las,n),np=findch(cur,s[n]);
if(!np){
np=++IN,td[np].len=td[cur].len+2;
td[np].fail=findch(getf(td[cur].fail,n),s[n]);
pp.first=cur,pp.second=s[n];
ch[cur/N][pp]=np,td[np].num=td[td[np].fail].num+1;
}las=np;
return td[np].num;
}
int main()
{
td[1].len=-1,td[0].fail=td[1].fail=1;
n=read();
scanf("%s",s+1);
for(int i=1;i<=n;i++)p[n-i+1]=extend(i),(SUM+=p[n-i+1])%=MOD;
for(int i=1;(i<<1)<=n;i++)swap(s[i],s[n-i+1]);
while(IN>=0)td[IN].CL(),IN--;
for(int i=0;i<=n/N;i++)ch[i].clear();
las=0,IN=1,td[1].len=-1,td[0].fail=td[1].fail=1;
for(int i=1;i<=n;i++)
extend(i),ans=(sum*p[i]+ans)%MOD,sum=(sum+td[las].num)%MOD;
ans=(SUM%MOD*(SUM-1)%MOD*25561994ll%MOD-ans+MOD)%MOD;
printf("%lld\n",ans);
return 0;
}