这道题思路并不难想(假的!):先用manacher算法求出以s[i]为中心的最长回文子串左右扩展的长度,再分别推出以s[i]结尾和开头的回文子串(注:不一定是最长的。如原串为aacaa,则i=3时,en[i]=2而不是1)数量,然后其中一组乘上另外一组的后缀和(前缀和)相加即可。简单分析一下第一个样例:
原数组的下标 i: 0 1 2
a c a
以s[i]结尾的回文串数量: 1 1 2 -> en[]数组
以s[i]开头的回文串数量: 2 1 1 -> st[]数组
则答案为1×(1+1) + 1×1 = 3
问题主要是如何推出st,en数组。我一开始是这么写的:
for(int i=2;i<l;i++)
{
int k=i+p[i]-1;//最长回文串扩展的右边界(a数组中)
int j=ceil((i*1.0-2)/2);//回文串中心(映射到原s数组下标)
while(j<=(k-2)/2)//同样要映射到s数组的下标
en[j]++,j++;
}
for(int i=0,j=len-1;i<len&&j>=0;i++,j--)
st[i]=en[j];
两层循环果断TLE。。只能换思路。。后来网上一搜发现可以用树状数组来维护。
先附上大佬博客Orz:https://www.cnblogs.com/liyinggang/p/5675916.html
https://blog.csdn.net/hexianhao/article/details/51823113
以en数组为例,假设以i为中心(a[i]不一定是字母,要“映射”到新下标(从1开始)),则从i到i+p[i]-1这些点的en值都要加1。我不太明白为什么把到右端点的数-1,再把到中心的数+1(我感觉刚好反了)。后来我写了一发但还是WA了QAQ...不明白为什么会酱紫55555...可能还是对树状数组的理解不够?把solve函数放在这里,如果有大神知道为什么错了还请不吝赐教^_^
ll st[MAX];//以s[i]开头的回文子串数
ll en[MAX];//以s[i]结尾的回文子串数
ll c[MAX];
ll sum[MAX];
int lowbit(int x)
{
return x&(-x);
}
void update(int i,int val)
{
while(i<=n)
{
c[i]+=val;
i+=lowbit(i);
}
}
int get_sum(int i)
{
int ret=0;
while(i>0)
{
ret+=c[i];
i-=lowbit(i);
}
return ret;
}
void fun()
{
memset(c,0,sizeof(c));
for(int i=2;i<l;i++)
{
int mid;//回文串中心(映射到原s数组下标,但注意下标从1开始)
if(i%2==0)//字母
mid=i/2;
else //'#'
mid=(i+1)/2;
int k=i+p[i]-1;//最长回文串扩展的右边界(a数组中)
int r=k/2+1;//映射到原s数组
if(r<=mid)
continue;
//cout<<"i="<<i<<" l="<<l<<" r="<<r<<endl;
update(mid,1);
update(r,-1);
//c[mid,r]+1
}
}
void solve() //得到st,en数组
{
memset(st,0,sizeof(st));
memset(en,0,sizeof(en));
fun();
for(int i=1;i<=n;i++)
en[i]=get_sum(i);
/*for(int i=1;i<=n;i++)
cout<<"i="<<i<<" en[i]="<<en[i]<<endl;*/
reverse(s,s+n);
manacher();
fun();
for(int i=1;i<=n;i++)
st[i]=get_sum(i);
}
我还是选择继续挣扎。。再一搜题解,发现可以用“差分前缀和”(第一次听说QAQ...)
附上讲解博客Orz:https://blog.csdn.net/hzk_cpp/article/details/80407014
https://www.cnblogs.com/lulizhiTopCoder/p/8384784.html
感觉有点像树状数组,处理思路也差不多。再附上本题的参考博客Orz:
https://blog.csdn.net/gatevin/article/details/44775533
经过长期挣扎。。终于过了。。注意WA点开long long啊!!!!!附上AC代码:
#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<stack>
#include<queue>
using namespace std;
#define ll long long
typedef pair<int,int>pp;
#define mkp make_pair
#define pb push_back
const int INF=0x3f3f3f3f;
const ll MOD=1e9+(ll)7;
const int MAX=100010;
char s[MAX];
int n;
char a[MAX*2];
int len,p[MAX*2];//以s[i]为中心的最长回文子串右(左)扩展的长度
void manacher()
{
memset(p,0,sizeof(p));
len=0;
a[len++]='$';
a[len++]='#';
for(int i=0;i<n;i++)
{
a[len++]=s[i];
a[len++]='#';
}
a[len]='\0';
int mx=0,id=0;
for(int i=0;i<len;i++)
{
p[i]=(mx-i)?min(p[2*id-i],mx-i):1;
while(a[i+p[i]]==a[i-p[i]])
p[i]++;
if(i+p[i]>mx)
{
mx=i+p[i];
id=i;
}
}
/*for(int i=0;i<len;i++)
cout<<a[i]<<" ";
cout<<endl;
for(int i=0;i<len;i++)
cout<<p[i]<<" ";
cout<<endl;
cout<<"len="<<len<<endl;*/
}
ll st[MAX];//以s[i]开头的回文子串数
ll en[MAX];//以s[i]结尾的回文子串数
ll dp1[MAX],dp2[MAX];//差分
ll sum[MAX];//st的后缀和
void solve() //得到st,en数组
{
memset(dp1,0,sizeof(dp1));
memset(dp2,0,sizeof(dp2));
for(int i=2;i<len;i++)
{
int l,r;
l=i-(p[i]-1);r=i;//a数组下标,要映射到s数组
if(l%2) l/=2;
else l=l/2-1;
r=r/2-1;
if(l<=r)
dp1[l]++,dp1[r+1]--;//对应st数组
l=i;r=i+(p[i]-1);
if(l%2) l/=2;
else l=l/2-1;
r=r/2-1;
if(l<=r)
dp2[l]++,dp2[r+1]--;//对应en数组
}
st[0]=dp1[0];en[0]=dp2[0];
for(int i=1;i<n;i++)
{
st[i]=st[i-1]+dp1[i];
en[i]=en[i-1]+dp2[i];
}
/*for(int i=0;i<n;i++)
cout<<en[i]<<" ";
cout<<endl;
for(int i=0;i<n;i++)
cout<<st[i]<<" ";
cout<<endl;*/
}
int main()
{
while(scanf("%s",s)==1)
{
n=strlen(s);
manacher();
solve();
sum[n-1]=st[n-1];//st的后缀和
for(int i=n-2;i>=0;i--)
sum[i]=sum[i+1]+st[i];
ll ans=0;
for(int i=0;i<n-1;i++)
{
ans+=en[i]*sum[i+1];
}
printf("%lld\n",ans);
}
return 0;
}
另外这道题更普遍的解法是回文树,但是我不会啊55555...回头有时间好好学习一下,再做一下这道题吧。