原题链接:https://www.lydsy.com/JudgeOnline/problem.php?id=4503
两个串
Description
兔子们在玩两个串的游戏。给定两个字符串S和T,兔子们想知道T在S中出现了几次,分别在哪些位置出现。注意T中可能有“?”字符,这个字符可以匹配任何字符。
Input
两行两个字符串,分别代表S和T
Output
第一行一个正整数k,表示T在S中出现了几次。
接下来k行正整数,分别代表T每次在S中出现的开始位置。按照从小到大的顺序输出,S下标从0开始。
Sample Input
bbabaababaaaaabaaaaaaaabaaabbbabaaabbabaabbbbabbbbbbabbaabbbababababbbbbbaaabaaabbbbbaabbbaabbbbabab
a?aba?abba
Sample Output
0
HINT
S 长度不超过 10^5, T 长度不会超过 S。 S 中只包含小写字母, T中只包含小写字母和“?”
题解
设 n = ∣ S ∣ , m = ∣ T ∣ n=|S|,m=|T| n=∣S∣,m=∣T∣,定义函数 f ( x ) = ∑ i = 0 m − 1 ( S x + i − T i ) 2 × T i f(x)=\sum_{i=0}^{m-1}(S_{x+i}-T_i)^2\times T_i f(x)=∑i=0m−1(Sx+i−Ti)2×Ti,其中小写字母的值直接用 A C S I I ACSII ACSII码,’ ? ? ? ’ 的值为 0 0 0,这样如果 f ( x ) = 0 f(x)=0 f(x)=0,那么说明可以从 x x x开始匹配。
考虑如何计算 f ( x ) f(x) f(x),将 ( S x + i − T i ) 2 × T i (S_{x+i}-T_i)^2\times T_i (Sx+i−Ti)2×Ti展开,得到 S x + i 2 T i − 2 S x + i T i 2 + T i 3 S^2_{x+i}T_i-2S_{x+i}T^2_i+T_i^3 Sx+i2Ti−2Sx+iTi2+Ti3,前两个为乘积形式, ∑ \sum ∑起来是平行的:
这样我们没法快速计算,所以我们把 T T T翻转过来:
运算变成了卷积形式,我们就可以用 F F T \mathcal{FFT} FFT快速计算了。
代码
#include<bits/stdc++.h>
#define db double
using namespace std;
const int M=(1<<18)+5;
const db pi=acos(-1.0);
struct cpx{db x,y;}s[M],t[M],s2[M],t2[M];
cpx operator +(cpx a,cpx b){return (cpx){a.x+b.x,a.y+b.y};}
cpx operator -(cpx a,cpx b){return (cpx){a.x-b.x,a.y-b.y};}
cpx operator *(cpx a,cpx b){return (cpx){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};}
int n,m,len,mx,rev[M],ans[M],tot,t3;
char ch[M];
void fft(cpx *f,int typ)
{
cpx wn,w,x,y;int i,j,k,mid;
for(i=0;i<mx;++i)if(i<rev[i])swap(f[i],f[rev[i]]);
for(mid=1;mid<mx;mid<<=1)for(j=0,wn=(cpx){cos(pi/mid),typ*sin(pi/mid)};j<mx;j+=mid<<1)
for(k=0,w=(cpx){1,0};k<mid;++k,w=w*wn)x=f[j+k],y=w*f[j+mid+k],f[j+k]=x+y,f[j+mid+k]=x-y;
}
void in()
{
scanf("%s",ch);for(int i=n=strlen(ch)-1;i>=0;--i)s[i].x=ch[i]-'a'+1;
scanf("%s",ch);for(int i=m=strlen(ch)-1;i>=0;--i)t[i].x=ch[i]=='?'?0:ch[i]-'a'+1;
}
void ac()
{
for(mx=1;mx<=n+m;mx<<=1,++len);for(int i=0;i<mx;++i)rev[i]=rev[i>>1]>>1|((i&1)<<(len-1));
reverse(t,t+m+1);
for(int i=0;i<mx;++i)s2[i]=s[i]*s[i],t2[i]=(cpx){2,0}*t[i]*t[i],t3+=(t[i]*t[i]*t[i]).x;
fft(s2,1),fft(t2,1),fft(s,1),fft(t,1);
for(int i=0;i<mx;++i)s2[i]=s2[i]*t[i]-s[i]*t2[i];
fft(s2,-1);
for(int i=0;i<=n-m;++i)if(int(s2[m+i].x/mx+t3+0.5)==0)ans[++tot]=i;
printf("%d\n",tot);for(int i=1;i<=tot;++i)printf("%d\n",ans[i]);
}
int main(){in();ac();}