链接
题解
先用 m a n a c h e r manacher manacher求出每个位置回文串的最长长度
这个答案怎么算呢?
如果两个回文串相交,那么肯定其中一个回文串的末尾在另一个回文串所涵盖的区间中出现了
我求出 s 1 [ i ] s1[i] s1[i]表示 i i i位置有多少回文串的末尾, s 2 [ i ] s2[i] s2[i]表示 i i i位置这里被多少个回文串覆盖
那么按照上面说的意思, ∑ i = 1 n s 1 [ i ] ( s 2 [ i ] − 1 ) \sum_{i=1}^n s1[i](s2[i]-1) ∑i=1ns1[i](s2[i]−1)似乎就是答案,我 − 1 -1 −1是为了不能把自己和自己覆盖的那种情况算进去。但是注意我说“似乎”,其实这个式子算出来多了一部分,就是两个回文串结尾重合的情况,所以再减去 ∑ i = 1 m 1 2 ( s 1 [ i ] 2 ) \sum_{i=1}^m \frac{1}{2} \binom{s1[i]}{2} ∑i=1m21(2s1[i])即可
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 4000010
#define cl(x) memset(x,0,sizeof(x))
#define rep(i,a,b) for(i=a;i<=b;i++)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
struct Manacher
{
int r[maxn], p[maxn], n;
void clear(){cl(r), cl(p);}
void calc(char *s, int N)
{
n=N;
int i, j, mx(0), center;
r[0]=-2;
for(i=1;i<=N;i++)r[2*i]=s[i];
for(i=1;i<=N;i++)r[2*i-1]=-1;
r[2*N+1]=-1;
for(i=1;i<=2*N+1;i++)
{
if(mx>=i)p[i]=min(p[2*center-i],mx-i+1);
else p[i]=1;
while(r[i-p[i]]==r[i+p[i]])p[i]++;
if(i+p[i]-1>mx)
{
mx=i+p[i]-1;
center=i;
}
}
}
}mnc;
char s[maxn];
ll n, s1[maxn], s2[maxn];
#define mod 1'000'000'007
int main()
{
int i;
n=read();
scanf("%s",s+1);
mnc.calc(s,n);
rep(i,1,2*n)
{
ll len=mnc.p[i]-1;
if(len==0)continue;
if(i&1)
{
ll half=len>>1;
s1[i+1>>1]++;
s1[(i+1>>1)+half]--;
s2[(i+1>>1)-half]++;
s2[(i+1>>1)]--;
s2[(i+1>>1)+1]--;
s2[(i+1>>1)+half+1]++;
}
else
{
ll half=len+1>>1;
s1[i>>1]++;
s1[(i>>1)+half]--;
s2[(i>>1)-half+1]++;
s2[(i>>1)+1]--;
s2[(i>>1)+1]--;
s2[(i>>1)+half+1]++;
}
}
rep(i,1,n)s1[i]+=s1[i-1];
rep(i,1,n)s2[i]+=s2[i-1];
rep(i,1,n)s2[i]+=s2[i-1];
ll ans=0;
rep(i,1,n)
{
(ans += (ll)s1[i]*(s2[i]-1))%=mod;
(ans -= (ll)s1[i]*(s1[i]-1)>>1 )%=mod;
}
printf("%lld",(ans+mod)%mod);
return 0;
}