bzoj4503 两个串

版权声明:虽然是个蒟蒻但是转载还是要说一声的哟 https://blog.csdn.net/jpwang8/article/details/79953814

Description


兔子们在玩两个串的游戏。给定两个字符串S和T,兔子们想知道T在S中出现了几次,
分别在哪些位置出现。注意T中可能有“?”字符,这个字符可以匹配任何字符。

S下标从0开始。
S 长度不超过 10^5, T 长度不会超过 S。 S 中只包含小写字母, T中只包含小写字母和“?”

来自 https://www.lydsy.com/JudgeOnline/problem.php?id=4503

Solution


原来fft还能这么用

我们定义两个长度为n的字符串距离为i=1ns2[i](s1[i]s2[i])2
其中当s2=’?’时为0,这样就能保证这是通配符惹

拆一下可以发现两个项可以FFT,另一个项直接求就好了

Code


#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <complex>

typedef std:: complex <double> com;
typedef double db;

const db pi=3.1415926535897932;
const int N=300005;

char str1[N],str2[N];

int a[N],b[N],rev[N],ans[N];

com c[N],d[N],e[N],f[N];

void FFT(com *a,int len,db f) {
    for (int i=0;i<len;i++) if (i<rev[i]) std:: swap(a[i],a[rev[i]]);
    for (int i=1;i<len;i*=2) {
        com wn(cos(pi/i),f*sin(pi/i));
        for (int j=0;j<len;j+=i*2) {
            com w(1,0);
            for (int k=0;k<i;k++) {
                com u=a[j+k],v=w*a[j+k+i];
                a[j+k]=u+v; a[j+k+i]=u-v;
                w*=wn;
            }
        }
    }
    if (f==-1) for (int i=0;i<len;i++) a[i]/=len;
}

int main(void) {
    scanf("%s",str1+1); scanf("%s",str2+1);
    int len1=strlen(str1+1),len2=strlen(str2+1),tmp=0;
    for (int i=1;i<=len1;i++) a[i]=str1[i]-'a'+1;
    for (int i=1;i<=len2;i++) b[len2-i+1]=(str2[i]=='?')?(0):(str2[i]-'a'+1);

    int len,lg; for (len=1,lg=0;len<=len1*2;len*=2,lg++);
    for (int i=0;i<=len;i++) rev[i]=(rev[i/2]/2)|((i&1)<<(lg-1));

    for (int i=1;i<=len2;i++) tmp+=b[i]*b[i]*b[i];
    for (int i=1;i<=len1;i++) c[i]=a[i]*a[i];
    for (int i=1;i<=len2;i++) d[i]=b[i];

    for (int i=1;i<=len1;i++) e[i]=a[i];
    for (int i=1;i<=len2;i++) f[i]=b[i]*b[i]*2;

    FFT(c,len,1); FFT(d,len,1); FFT(e,len,1); FFT(f,len,1);
    for (int i=0;i<len;i++) c[i]*=d[i],e[i]*=f[i];
    FFT(c,len,-1); FFT(e,len,-1);

    for (int i=0;i<len;i++) ans[i]=(int)(0.1+c[i].real())-(int)(0.1+e[i].real());
    int prt=0;
    for (int i=len2+1;i<=len1+1;i++) if (ans[i]+tmp==0) prt++;
    printf("%d\n", prt);
    for (int i=len2+1;i<=len1+1;i++) if (ans[i]+tmp==0) printf("%d\n", i-len2-1);
    return 0;
}
阅读更多

没有更多推荐了,返回首页