题意:给定一个长度为n的字符串,求有多少对回文子串相交(即两个回文子串至少有一个公共位置),答案对51123987取模。
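做法:直接统计相交的回文串对比较困难,但统计不相交的对却很简单,因此答案 = 总对数 $\binom{x}{2}$ − 不相交对数(x为回文子串总数)。先用Manacher算法在O(n)时间内求出以每个位置为中心的最长回文半径,再求出start、end数组,分别表示以位置i开头的回文串个数和以位置i结尾的回文串个数,这可以通过差分标记再做前缀和(end数组做后缀和)得到。不相交的回文串对数为 $\sum_{j=2}^{n} \left( start[j] \cdot \sum_{i=1}^{j-1} end[i] \right)$,即每个以j开头的回文串与每个在j之前结束的回文串组成一对不相交的回文串。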
PS:因为n高达2*10^6,如果整个串都是同一个字母,那么回文子串总数 x = n*(n+1)/2(约2*10^12),直接计算 C(x,2) = x*(x-1)/2 时中间乘积会超出64位整数范围,需要先把偶数因子除以2,再分别对两个因子取模后相乘,时刻注意溢出问题。这里WA了好几发。
AC代码:
#pragma comment(linker, "/STACK:102400000,102400000")
#include<cstdio>
#include<ctype.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<vector>
#include<cstdlib>
#include<stack>
#include<queue>
#include<set>
#include<map>
#include<cmath>
#include<ctime>
#include<string.h>
#include<string>
#include<sstream>
#include<bitset>
using namespace std;
// 64-bit signed integer type used for every count in this file.
// `__int64` is an MSVC/MinGW extension that GCC/Clang on Linux do not
// provide; `long long` is standard and at least 64 bits wide everywhere.
typedef long long ll;
#define ull unsigned long long
#define eps 1e-8
#define NMAX 1000000000
//#define MOD 51123987
// Argument triples for the halves [l, mid] / [mid+1, r] with child index rt*2 / rt*2+1.
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define PI acos(-1)
// Whole-container iterator pair and end-inserter shorthands.
#define ALL(x) x.begin(), x.end()
#define INS(x) inserter(x, x.end())
/**
 * Reads the next decimal integer (optionally preceded by '-') from stdin.
 *
 * Characters that are neither digits nor '-' are skipped first; then
 * consecutive digits are accumulated. The character that ends the number
 * is consumed. If end of input is reached before any digit or '-' is seen,
 * the function returns with ret == 0 instead of spinning forever.
 *
 * @param ret destination; always overwritten.
 */
template<class T>
inline void scan_d(T &ret)
{
    int c;                 // int, not char: getchar() may return EOF
    bool negative = false;
    ret = 0;
    // Skip separators; stop at a digit, a minus sign, or end of input.
    while ((c = getchar()) != EOF && (c < '0' || c > '9') && c != '-')
        ;
    if (c == EOF) return;
    if (c == '-')
    {
        negative = true;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        ret = ret * 10 + (c - '0');
        c = getchar();
    }
    if (negative) ret = -ret;
}
// Larger of two values by operator>; yields b when a is not greater than b.
template<class T> inline T Max(T a, T b)
{
    if (a > b) return a;
    return b;
}
// Smaller of two values by operator<; yields b when a is not less than b.
template<class T> inline T Min(T a, T b)
{
    if (a < b) return a;
    return b;
}
// Capacity for the input string length (with slack).
const int maxn = 2100000+10;
// s: raw input string.
// ch: transformed string "$#c1#c2#...#cn#" plus a terminator, about 2*n+3 chars.
char s[maxn],ch[maxn*2];
// str[i]: Manacher radius at ch[i], i.e. the largest j such that
// ch[i-j+1 .. i+j-1] is a palindrome. Values are bounded by the length of ch,
// so int is enough. Using int rather than a 64-bit type halves this array,
// the largest one in the program.
int str[maxn*2];
// st[i] / ed[i]: number of palindromic substrings starting / ending at the
// 1-based position i (ed is later turned into prefix sums in main).
ll st[maxn],ed[maxn];
// Difference arrays from which st (prefix sums) and ed (suffix sums) are built.
ll flag1[maxn],flag2[maxn];
// Modulus applied to every count and to the printed answer.
const ll MOD = 51123987;
/**
 * Modular accumulation: x = (x + y) mod MOD, result kept in [0, MOD).
 *
 * Precondition: x in [0, MOD) and y in (-MOD, MOD). For that range a single
 * correction step is enough. Negative y occurs for the -1 marks placed in
 * the difference arrays.
 */
inline void add(ll &x, ll y)
{
    x += y;
    if (x >= MOD) x -= MOD;     // >= so that x == MOD also wraps to 0
    else if (x < 0) x += MOD;   // keep negative marks as canonical residues
}
/**
 * For each test case on stdin (an integer length followed by the string),
 * prints the number of pairs of palindromic substrings that intersect,
 * modulo MOD.
 *
 * Answer = C(total, 2) - disjoint, where a disjoint pair is a palindrome
 * that ends strictly before another one starts.
 */
int main()
{
#ifdef GLQ
//    freopen("input.txt","r",stdin);
    freopen("o.txt","r",stdin);
//    freopen("o.txt","w",stdout);
#endif
    int len;
    // Require both fields; `~scanf(...)` also kept looping when only len was read.
    while(scanf("%d%s",&len,s) == 2)
    {
        // --- 1. Build ch = "$#s1#s2#...#sn#" followed by '\0' ---------------
        // ch[2k] is the k-th (1-based) character; odd indices hold '#'.
        int nlen = 1;
        for(int i = 0; i < len; i++)
        {
            ch[nlen++] = '#';
            ch[nlen++] = s[i];
        }
        ch[nlen++] = '#';
        ch[0] = '$';      // left sentinel
        // Right sentinel. ch is global, so without rewriting it the expansion
        // below could match leftovers of a previous, longer test case and
        // run past the end of the current string.
        ch[nlen] = '\0';

        // --- 2. Manacher: fill str[1 .. nlen-1] ---------------------------------
        // p is the right edge of the right-most palindrome found so far, and
        // p0 is its centre.
        int p0 = 0, p = 0;
        for(int i = 1; i < nlen; i++)
        {
            if(i > p)
            {
                int j = 1;
                for( ; ch[i-j] == ch[i+j]; j++);
                str[i] = j;
                p0 = i; p = i+j-1;
            }
            else
            {
                int dui = 2*p0-i;   // mirror of i around p0
                if(str[dui] < p-i+1) str[i] = str[dui];
                else
                {
                    // Mirror reaches the edge: resume expanding from it.
                    int j = p-i+1;
                    for( ; ch[i-j] == ch[i+j]; j++);
                    str[i] = j;
                    p0 = i; p = i+j-1;
                }
            }
        }
        memset(st,0,sizeof(st));
        memset(ed,0,sizeof(ed));

        // --- 3. Total palindrome count ge, then C(ge, 2) mod MOD ---------------
        // A '#' centre (odd i) contributes (str[i]-1)/2 even-length palindromes;
        // a letter centre (even i) contributes (str[i]-1)/2 + 1 odd-length ones.
        ll ge = 0;
        for(int i = 1; i < nlen; i++)
        {
            if(i&1) ge += (str[i]-1LL)/2LL;
            else ge += (str[i]-1LL)/2LL+1LL;
        }
        // ge can reach about n(n+1)/2 (~2e12), so ge*(ge-1) would overflow
        // 64 bits: halve the even factor first, then reduce both mod MOD.
        if(ge%2LL == 0) ge = ((ge/2LL)%MOD)*((ge-1LL)%MOD);
        else ge = ((ge-1LL)/2LL%MOD)*(ge%MOD);
        ge %= MOD;

        // --- 4. Difference marks: flag1 -> starts, flag2 -> ends ---------------
        memset(flag1,0,sizeof(flag1));
        memset(flag2,0,sizeof(flag2));
        for(int i = 1; i < nlen; i++)
        {
            if(i&1)
            {
                // '#' centre between characters (i-1)/2 and (i+1)/2.
                // No even palindrome here when the radius is 1.
                if(str[i] == 1) continue;
                // starts: [(i-str[i]+2)/2, (i-1)/2]
                add(flag1[(i-str[i]+2)/2],1);
                add(flag1[(i+1)/2],-1);
                // ends: [(i+1)/2, (i+str[i]-2)/2]  (suffix-sum marks)
                add(flag2[(i+str[i]-2)/2],1);
                add(flag2[(i-1)/2],-1);
            }
            else
            {
                // Letter centre at character i/2.
                // starts: [(i-str[i]+2)/2, i/2]
                add(flag1[(i-str[i]+2)/2],1);
                add(flag1[(i+2)/2],-1);
                // ends: [i/2, (i+str[i]-2)/2]  (suffix-sum marks)
                add(flag2[(i+str[i]-2)/2],1);
                add(flag2[(i-2)/2],-1);
            }
        }
        // st[i]: palindromes starting at position i (prefix sums of flag1).
        for(int i = 1; i <= len; i++)
        {
            add(flag1[i],flag1[i-1]);
            st[i] = flag1[i];
        }
        // ed[i]: palindromes ending at position i (suffix sums of flag2).
        flag2[len+1] = 0;
        for(int i = len; i >= 1; i--)
        {
            add(flag2[i],flag2[i+1]);
            ed[i] = flag2[i];
        }

        // --- 5. Disjoint pairs ---------------------------------------------------
        // Turning ed into prefix sums on the fly, ed[i-1] holds (mod MOD) the
        // palindromes ending at or before i-1 when paired with those starting at i.
        ll ans = 0;
        for(int i = 2; i <= len; i++)
        {
            add(ed[i],ed[i-1]);
            add(ans,ed[i-1]*st[i]%MOD);
        }
        // The result is below MOD, so it fits in int and prints portably
        // (unlike the MSVC-only "%I64d").
        printf("%d\n", static_cast<int>((ge-ans+MOD)%MOD));
    }
    return 0;
}