中超7 hdu 7055 Yiwen with Sqc(NTT板子,FFT加法,加减法转换乘法,思维,巧妙转换)

53 篇文章 0 订阅
49 篇文章 1 订阅

hdu 7055 Yiwen with Sqc

本篇题解非正解(超时了)正解传送门

题意:

​ 字符串的子串当中每个字母的出现个数(用 s [ ′ a ′ ] [ l , r ] s['a'] [l, r] s[a][l,r] 来表示),要求的就是每个字母的所有 s [ l , r ] s[l, r] s[l,r] 的平方和

分析:

要求所有的子串, C n 2 C_{n}^2 Cn2​​ ,暴力的话是 O ( n 2 ) O(n^2) O(n2) 的算法,优化 O ( n 2 ) O(n^2) O(n2) 的算法我最近刚学了一个 F F T FFT FFT (还热乎着呢…),那么此题能否用 F F T FFT FFT 优化呢?

众所周知, F F T FFT FFT 是优化多项式乘法的,考虑如何将这道题转换成多项式乘法的形式

这道题是让求任意区间的子串当中某一字母的个数的平方和,求区间和要用前缀和的形式,然后再 s [ r ] − s [ l − 1 ] s[r] - s[l-1] s[r]s[l1] ,便是区间 [ l , r ] [l, r] [l,r] 的贡献值

等会,我发现了什么!!! s [ r ] − s [ l − 1 ] s[r] - s[l-1] s[r]s[l1] ,这是啥,敲重点,这是差值欸

那么,差值又能干什么

学过FFT的都知道,差值(也就是减法)能转换成加法,然后加法又能转换成乘法,系数相乘就能用 F F T FFT FFT 优化了

那么这题就要去考虑什么作为多项式的次数,什么又作为多项式的系数,题目要计算的是每一差值出现的次数,差值是由两个不同的 s [ i ] s[i] s[i] 产生的,显然 s [ i ] s[i] s[i] 是作为次数的,系数便是 s [ i ] s[i] s[i] 的个数

剩下便是FFT的板子了,还是有几个坑点要注意的

Code:

代码是正解,但是标程要 O ( n ) O(n) O(n) 的算法,可怜我改了两三个小时的代码 T_T )

#include<bits/stdc++.h>
#define int long long 
using namespace std;

const int N = (1<<20)+5, mo=998244353;
const double PI=acos(-1);
struct Complex
{
    double x, y;
    Complex operator+(const Complex &o) const{return{x+o.x,y+o.y};}  
    Complex operator-(const Complex &o) const{return{x-o.x,y-o.y};}  
    Complex operator*(const Complex &o) const{return{x*o.x-y*o.y,x*o.y+y*o.x};} 
}A[N],B[N];
int rev[N];
void init(int k)
{
    int s=1<<k;
    for(int i=1;i<s;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
}
void fft(Complex *a,int n,int inv)
{
    for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int len=1;len<n;len<<=1)
    {
        Complex Wn=Complex({cos(PI/len),inv*sin(PI/len)});
        for(int i=0;i<n;i+=len*2)
        {
            Complex w=Complex({1,0});
            for(int j=0;j<len;j++,w=w*Wn)
            {
                Complex x=a[i+j],y=w*a[i+j+len];
                a[i+j]=x+y,a[i+j+len]=x-y;
            }
        }
    }
    if(inv==-1) for(int i=0;i<n;i++) A[i].x = A[i].x/n+0.5; // 精度
}
// --------FFT
const int M=1e5+5;
char s[M];
int dp[33][M];
signed main()
{
    int t;
    scanf("%lld",&t);
    while(t--)
    {
        scanf("%s",s+1);
        int n=strlen(s+1);
        //memset(dp,0,sizeof(dp));
        for(int i=1;i<=n;i++)
        {
            for(int j=0;j<26;j++)
            {
                dp[j][i] = dp[j][i-1];
            } 
            dp[s[i]-'a'][i]++;
        }
        int ans=0,sum=2,k=1;
        while(sum <= n+n) k++,sum <<= 1;
        init(k);
        for(int num=0;num<26;num++)
        {
            //memset(A,0,sizeof(A));
            //memset(B,0,sizeof(B));
            // 无力的挣扎,试图省点时间
            for(int i=0;i<sum;i++) A[i]=B[i]=Complex({0,0});
            B[n].x=1;
            // 这步de了一小时才想明白,傻了
            for(int j=1;j<=n;j++)
            {
                A[dp[num][j]].x++;
                B[n-dp[num][j]].x++;
            }
            if((int)A[0].x==n) continue; // 也是为了...
            //if(B[n].x < 1) B[n].x=1;
            //for(int i=0;i<=n;i++) cout<<A[i].x<<' '; cout<<endl;
            //for(int i=0;i<=n;i++) cout<<B[i].x<<' '; cout<<endl;
            fft(A,sum,1); fft(B,sum,1);
            for(int i=0;i<sum;i++) A[i] = A[i]*B[i];
            fft(A,sum,-1);
            //for(int i=1;i<=n;i++) cout<<(int)A[i+n].x<<' '; cout<<endl;
            for(int i=1;i<=n;i++) ans = (ans+(int)A[i+n].x*i*i)%mo;
            //cout<<char('a'+num)<<' '<<ans<<endl;
        }
        printf("%lld\n",ans);
    }

    return 0;
}

B [ n ] . x = 1 B[n].x = 1 B[n].x=1 为啥要有这个初始化,因为所有的区间 [ 0 , i ] [0, i] [0,i] 都要考虑一下,不能忘了,就是这个 0 0 0 ,用了各种办法来处理都不行,比赛的时候思路太乱

能过样例的那一刻巨开心,结果…

在这里插入图片描述

这题加深了我对FFT问题的理解,主要是如何转换方面,多项式间的乘法,首先要先能转换到乘法或加法,然后就是考虑什么作为多项式次数,什么作为系数,还有就是边界问题的处理。

然后这题我还信了网上的鬼话, N T T NTT NTT 跑的比 F F T FFT FFT 快,下面是 N T T NTT NTT 的代码(其实速度上没多大区别)

#include<bits/stdc++.h>
#define int long long 
using namespace std;

const int N = (1<<20)+5, mo=998244353;

int rev[N];
void init(int k)
{
    int s=1<<k;
    for(int i=1;i<s;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
}
int binpow(int a,int b)
{
    int ans=1;
    while(b)
    {
        if(b&1) ans = ans*a%mo;
        a = a*a%mo;
        b >>= 1;
    }
    return ans;
}
void ntt(int *a, int n, int inv)
{
    for(int i = 0; i < n; i++) 
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int mid = 1; mid < n; mid <<= 1) {    
        int Wn = binpow(3, (mo-1)/(mid<<1));
        if(inv == -1) Wn = binpow(Wn, mo-2);
        for(int j = 0; j < n; j += (mid << 1)) {
            int w = 1;
            for(int k = 0; k < mid; k++, w = (w * Wn) % mo) {
                 int x = a[j + k], y = w * a[j + k + mid] % mo;
                 a[j + k] = (x + y) % mo,
                 a[j + k + mid] = (x - y + mo) % mo;
            }
        }
    }
    if(inv == -1)
    {
        int fg=binpow(n, mo-2);
        for(int i=0;i<n;i++) a[i] = a[i]*fg%mo;
    }
}
const int M=1e5+5;
char s[M];
int dp[33][M], a[N], b[N];
signed main()
{
    //freopen("1.in","r",stdin);
    //freopen("1.out","w",stdout);
    int t;
    scanf("%lld",&t);
    while(t--)
    {
        scanf("%s",s+1);
        int n=strlen(s+1);
        //memset(dp,0,sizeof(dp));
        for(int i=1;i<=n;i++)
        {
            for(int j=0;j<26;j++)
            {
                dp[j][i] = dp[j][i-1];
            } 
            dp[s[i]-'a'][i]++;
        }
        int ans=0,sum=2,k=1;
        while(sum <= n+n) k++,sum <<= 1;
        init(k);
        for(int num=0;num<26;num++)
        {
            //memset(A,0,sizeof(A));
            //memset(B,0,sizeof(B));
            for(int i=0;i<sum;i++) a[i]=b[i]=0;
            b[n] = 1;
            for(int j=1;j<=n;j++)
            {
                a[dp[num][j]]++;
                b[n-dp[num][j]]++;
            }
            if(a[0]==n) continue;
            //if(B[n].x < 1) B[n].x=1;
            //for(int i=0;i<=n;i++) cout<<A[i].x<<' '; cout<<endl;
            //for(int i=0;i<=n;i++) cout<<B[i].x<<' '; cout<<endl;
            ntt(a,sum,1); ntt(b,sum,1);
            for(int i=0;i<sum;i++) a[i] = a[i]*b[i];
            ntt(a,sum,-1);
            //for(int i=1;i<=n;i++) cout<<(int)A[i+n].x<<' '; cout<<endl;
            for(int i=1;i<=n;i++) ans = (ans+a[i+n]*i*i)%mo;
            //cout<<char('a'+num)<<' '<<ans<<endl;
        }
        printf("%lld\n",ans);
    }

    return 0;
}
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yezzz.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值