题目链接:https://www.luogu.com.cn/problem/P6216
题意:
题目描述
对于一对字符串s1,s2,若s1的长度为奇数的子串 (l,r),满足 (l,r) 是回文的,那么 s1 的“分数”会增加s2 在 (l,r) 中出现的次数。
现在给出一对 (s1,s2),请计算出 s1 的“分数”。
答案对 2^32 取模。
其中1<=|s2|<=|s1|<=3e6。
输入格式
第一行两个整数,n,m,表示 s1 的长度和 s2 的长度。
第二行两个字符串 s1,s2。
输出格式
一行一个整数,表示 s1 的分数。
题解:思路还是很好想到的,但是细节需要注意一下。
首先kmp标记s2在s1中出现的位置(vis标记),然后manacher维护len[i]表示以i为中心的最长回文串的长度(注意abcba,len[3]=2,len[1]=0)。
然后考虑以i为起点的最长回文串中包含了多少s2。方法,[l,r]中所有回文串包含vis的个数(先不考虑s2的长度,即认为其长度为1):l<=i<=r,id表示回文中点。
分为两段[l,id]:。展开即为(表示以i为起点的回文串中包含s2的个数,*(i-l+1)是因为既然以i为起点包含s2的个数,那么当下表>=l且<=i的都包含s2的个数(累加的)。
[id+1,r]:。展开即为这里表示以i为终点的回文串中包含s2的个数,*(r-i+1)与上面一个道理。
所以只需要维护一下 vis[i]*i 和 vis[i] 的前缀和就可以O(1)求出以 id 为回文中心的所有回文串中包含s2的个数的和了。
注意:
1.考虑s2的长度。那么l,r就不是id-len[id]和id+len[id]了。可以发现l不变,但是r就会改变为i+len[i]-|s2|+1。
令mid=(l+r)/2,那么[l,mid]为起点必然<=id。[mid+1,r]+|s2|-1为终点必然>=id。一一对应了原[l,r]中某点为起点,某点为终点(去除了长度小于|s2|的回文)。
总之,以mid区分就是因为mid之前的数一定可以为起点,mid+1+|s2|-1之后的数一定可以为终点,不重叠(<=id时只当是起点,否则视为终点,刚好连接),并且一定满足回文长度>=|s2|(证明:,注意这里代入得到:。显然成立,得证!至于以mid为起点就更容易了,就不说了)。就这样。比赛时是不可能让我来证明的,就需要我敏锐的嗅觉???
2.kmp已经习惯了,理解了从下标0开始,其实理解之后换成1也很容易,但还是尽量先以0开始,最后处理完之后再一下子变为以1开始就好(不熟练以1开始就这样)。
manacher也是这样:
另外还需要注意,在循环的时候需要限制i+len[i]<=n&&i-len[i]>=1(否则就应该令s[0]='$'一个不会出现的字符)
manacher示例:
mx=0,id=0;
rep(i,0,n-1){
if(mx>i) len[i]=min(len[2*id-i],len[id]+id-i);
else len[i]=1;
while(s1[i-len[i]]==s1[i+len[i]]&&i-len[i]>=0&&i+len[i]<n) ++len[i];
if(i+len[i]>mx) mx=i+len[i],id=i;
}
代码:
#include <bits/stdc++.h>
#define ll long long
#define ld double
#define pi acos(-1)
#define pb push_back
#define mst(a, i) memset(a, i, sizeof(a))
#define pll pair<ll, ll>
#define fi first
#define se second
#define mp(x,y) make_pair(x,y)
#define rep(i,a,n) for(ll i=a;i<=n;i++)
#define per(i,n,a) for(ll i=n;i>=a;i--)
#define dbg(x) cout << #x << "===" << x << endl
#define dbgg(l,r,x) for(ll i=l;i<=r;i++) cout<<x[i]<<" ";cout<<"<<<"<<#x;cout<<endl
using namespace std;
template<class T>void read(T &x){T res=0,f=1;char c=getchar();while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}while(isdigit(c)){res=(res<<3)+(res<<1)+c-'0';c=getchar();}x=res*f;}
inline void print(ll x){if(x<0){putchar('-');x=-x;}if(x>9)print(x/10);putchar(x%10+'0');}
const ll maxn = 6e6 + 10;
ll mod = 1;
ll n,m,a[maxn],vis[maxn],sum[maxn],isum[maxn],len[maxn];
string s1,s2;
int main() {
ll _s = 1;
//read(_s);
//freopen("testdata.in","r",stdin);
//freopen("testout.out","w",stdout);
for (ll _=1;_<=_s;_++) {
mod<<=32;
read(n),read(m);
cin>>s1>>s2;
string c=s2+"#"+s1;
// dbg(c);
ll t=0,lc=c.size();
a[0]=0;
rep(i,1,lc-1){
while(t&&c[t]!=c[i]) t=a[t-1];
t+=(c[t]==c[i]);
a[i]=t;
if(a[i]==m) vis[i-2*m+1]=1;
// dbg(i-2*m);
}//?以上kmp
// per(i,n,1) vis[i]=vis[i-1];
//?一下manacher
ll mx=0,id=0;
rep(i,0,n-1){
if(mx>i) len[i]=min(len[2*id-i],len[id]+id-i);
else len[i]=1;
while(s1[i-len[i]]==s1[i+len[i]]&&i-len[i]>=0&&i+len[i]<n) ++len[i];
if(i+len[i]>mx) mx=i+len[i],id=i;
}
per(i,n,1) len[i]=len[i-1]-1;
//?以上manacher
rep(i,1,n) sum[i]=sum[i-1]+vis[i],isum[i]=isum[i-1]+vis[i]*i;
ll ans=0,l,r,mid;
rep(i,1,n){
l=i-len[i],r=i+len[i]-m+1;
if(l>r) continue;
mid=(l+r)>>1;
ans+=(isum[mid]-isum[l-1])-(sum[mid]-sum[l-1])*(l-1);ans%=mod;
if(mid!=r) ans+=(sum[r]-sum[mid])*(r+1)-(isum[r]-isum[mid]);ans%=mod;
}
cout<<ans<<endl;
}
return 0;
}