有个很显然的做法就是建出回文自动机,然后在fail树上倍增找到长度小于等于一半的最长的回文后缀。
这样复杂度是 O(|S|log|S|)
但实际上每个点也可以维护一个half指针指向它的合法回文后缀,那么找当前点的half就沿着父节点的half指针的fail指针网上爬就好了
复杂度就是建回文自动机的复杂度 O(|S|log|Σ|)
fail树倍增
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
const int N=100010;
int t,n,a[N],fa[N][20];
ll tot[N],w[N];
int nxt[N][30],len[N],p,cnt;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline void rea(int &x){
char c=nc(); x=0;
for(;c>'9'||c<'0';c=nc());for(;c>='0'&&c<='9';x=x*10+c-'0',c=nc());
}
inline int rea(int *x){
char c=nc(); x=0; int len=0;
for(;c>'z'||c<'a';c=nc());for(;c>='a'&&c<='z';a[++len]=c-'a',c=nc()); return len;
}
inline ll check(int x){
int cur=len[x]>>1;
for(int i=17;~i;i--)
if(len[fa[x][i]]>cur) x=fa[x][i];
x=fa[x][0];
if(len[x]==cur) return w[x];
return 0;
}
inline ll extend(int c,int r){
while(a[r-len[p]-1]!=c) p=fa[p][0];
if(!nxt[p][c]){
int cur=++cnt,k=fa[p][0]; len[cur]=len[p]+2;
while(a[r-len[k]-1]!=c) k=fa[k][0];
fa[cur][0]=nxt[k][c];
for(int i=1;i<=17;i++) fa[cur][i]=fa[fa[cur][i-1]][i-1];
nxt[p][c]=cur; tot[cur]=tot[fa[cur][0]]+(w[cur]=check(cur)+1);
}
p=nxt[p][c];
return tot[p];
}
int main(){
rea(t);
while(t--){
n=rea(a); ll ans=0;
memset(nxt,0,sizeof(nxt));
fa[0][0]=fa[0][1]=1; a[0]=-1; len[1]=-1; p=0; cnt=1;
for(int i=1;i<=n;i++)
ans+=extend(a[i],i);
cout<<ans<<endl;
}
return 0;
}
half指针
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
const int N=100010;
int t,n,a[N],fail[N];
ll tot[N],w[N];
int nxt[N][30],len[N],half[N],p,cnt;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline void rea(int &x){
char c=nc(); x=0;
for(;c>'9'||c<'0';c=nc());for(;c>='0'&&c<='9';x=x*10+c-'0',c=nc());
}
inline int rea(int *x){
char c=nc(); x=0; int len=0;
for(;c>'z'||c<'a';c=nc());for(;c>='a'&&c<='z';a[++len]=c-'a',c=nc()); return len;
}
inline ll extend(int c,int r){
while(a[r-len[p]-1]!=c) p=fail[p];
if(!nxt[p][c]){
int cur=++cnt,k=fail[p]; len[cur]=len[p]+2;
while(a[r-len[k]-1]!=c) k=fail[k];
fail[cur]=nxt[k][c];
int t=half[p];
while(a[r-len[t]-1]!=c || len[nxt[t][c]]>(len[cur]>>1)) t=fail[t];
t=nxt[t][c];
half[cur]=t; tot[cur]=tot[fail[cur]]; w[cur]=1;
if(len[t]==(len[cur]>>1)) w[cur]+=w[t];
tot[cur]+=w[cur];
nxt[p][c]=cur;
}
p=nxt[p][c];
return tot[p];
}
int main(){
rea(t);
while(t--){
n=rea(a); ll ans=0;
memset(nxt,0,sizeof(nxt));
memset(w,0,sizeof(w));
memset(tot,0,sizeof(tot));
fail[0]=fail[1]=half[0]=half[1]=1; a[0]=-1; len[1]=-1; p=0; cnt=1;
for(int i=1;i<=n;i++)
ans+=extend(a[i],i);
cout<<ans<<endl;
}
return 0;
}