区间回文数量
链接:CF245H
题意:
- 给定一个长度为 n n n的字符串,有 q q q次询问,每次询问求出 l − r l-r l−r区间内回文的数量
- 数据范围 1 < = n < = 5000 q ( 1 < = q < = 1 0 6 ) 1<=n<=5000 q(1<=q<=10^6) 1<=n<=5000q(1<=q<=106)
- 例如 a b a aba aba形成的回文有: a , b , a , a b a a,b,a,aba a,b,a,aba四个回文串
分析:
- 这是一道区间DP的问题也可以用PAM实现,这里考虑区间DP,定义状态 d p [ l ] [ r ] : dp[l][r]: dp[l][r]:区间 l − r l-r l−r内回文的数量。
- 如果用常规的区间DP进行状态转移会出现问题,对于区间 l − r : d p [ l ] [ r ] = d p [ l ] [ k ] + d p [ k + 1 ] [ r ] l-r:dp[l][r]=dp[l][k]+dp[k+1][r] l−r:dp[l][r]=dp[l][k]+dp[k+1][r]通过枚举 k k k来进行转移会出现重复枚举的情况,而且时间复杂度也是 n 3 n^3 n3。
- 这里可以用容斥原理,对于区间 l − r l-r l−r:有 d p [ l ] [ r ] = d p [ l + 1 ] [ r ] + d p [ l ] [ r − 1 ] − d p [ l + 1 ] [ r − 1 ] + p [ l ] [ r ] , p [ l ] [ r ] dp[l][r]=dp[l+1][r]+dp[l][r-1]-dp[l+1][r-1]+p[l][r],p[l][r] dp[l][r]=dp[l+1][r]+dp[l][r−1]−dp[l+1][r−1]+p[l][r],p[l][r]代表区间 l − r l-r l−r是否为回文串。通过这个转移方程可以只枚举左右端点就能 n 2 n^2 n2的求出所有状态。
- 注意转移方程是对于 r − l + 1 > = 3 r-l+1>=3 r−l+1>=3的区间适用的,因此需要预处理出来长度为 1 1 1和 2 2 2的区间,判断区间 l − r l-r l−r是否为回文串可以预处理也可以用记忆化搜索
#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
#include<stack>
#define x first
#define y second
#define int long long
using namespace std;
typedef long long ll;
const int N=5e3+10,M=60;
int f[N][N];//l~r区间内回文的数量
int n,m;
char s[N];
int p[N][N];
bool Palin(int l,int r)
{
if(p[l][r]!=-1) return p[l][r];
if(s[l]!=s[r])
{
p[l][r]=0;
return 0;
}
else
{
p[l+1][r-1]=Palin(l+1,r-1);
return p[l+1][r-1];
}
}
void solve()
{
cin>>s+1;
n=strlen(s+1);
cin>>m;
memset(p,-1,sizeof p);
for(int i=1;i<n;i++)
{
f[i][i]=1;
p[i][i]=1;
if(s[i]==s[i+1])
{
f[i][i+1]=3;
p[i][i+1]=1;
}
else f[i][i+1]=2,p[i][i+1]=0;
}
f[n][n]=p[n][n]=1;
for(int len=3;len<=n;len++)
{
for(int l=1;l+len-1<=n;l++)
{
int r=l+len-1;
p[l][r]=Palin(l,r);
f[l][r]=f[l+1][r]+f[l][r-1]-f[l+1][r-1]+p[l][r];
}
}
// cout<<p[2][4]<<endl;
// cout<<f[2][3]<<" "<<f[3][4]<<" "<<f[3][3]<<endl;
// cout<<f[1][3]<<" "<<f[2][4]<<endl;
while(m--)
{
int l,r;
scanf("%d%d",&l,&r);
printf("%d\n",f[l][r]);
//cout<<f[l][r]<<endl;
}
}
signed main()
{
int t;
t=1;
//cin>>t;
for(int i=1;i<=t;i++)
{
solve();
}
return 0;
}