牛客练习赛50 F tokitsukaze and Another Protoss and Zerg(分治NTT)

一些废话:
写完上一题(分治FFT)之后记起之前牛客写的这道题,题解是用启发式NTT,每次选两个项数最小的合并,当时没有整理NTT的模板所以就先mark了,今天受到分治FFT的启发,想用分治NTT来解决这道题。
复杂度分析:一共会分出log(n)层,每层的项数总和数量级是 Σ a i \Sigma a_i Σai,ntt复杂度nlogn,所以总复杂度是 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n).
这题也是用生成函数,每一场的生成函数为: ( 2 b i − 1 ) + ∑ i = 1 a i C a i i x i (2^{b_i}-1)+\sum_{i=1}^{a_i}C_{a_i}^ix^i (2bi1)+i=1aiCaiixi
把每一场的函数乘起来之后 x i x^i xi的系数就是恰好取i个的答案。
一定要好好审题

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 2e5 + 50;
const ll mod = 998244353;
ll qm(ll a, ll b){ll res = 1; while(b){if(b&1) res = res*a%mod; a = a*a%mod; b >>= 1; } return res; }
ll mulwn[maxn<<2], invwn[maxn<<2];
void INIT(){
    for(ll i=1;i < maxn*4;i<<=1) mulwn[i]=qm(3,(mod-1)/i);
    for(ll i=1;i < maxn*4;i<<=1) invwn[i]=qm(mulwn[i],mod-2);
}
struct FFT {
    ll n, m, rev[maxn << 2];
    ll a[maxn << 2], b[maxn << 2];
    void init(int len) {
        for (n = 1, m = 0; n < len + len; n <<= 1, m++);
        for (int i = 0; i < n; ++i) {
            rev[i] = (rev[i >> 1] >> 1) | (i & 1) << (m - 1);
            a[i] = 0;
            b[i] = 0;
        }
    }
    void ntt(ll *a, int f) {//
        for (int i = 0; i < n; ++i)if (i < rev[i])swap(a[i], a[rev[i]]);

        for (int k = 2; k <= n; k <<= 1) {
            ll wn=(f>0)?mulwn[k]:invwn[k];
            int mid = k>>1;
            for(int i = 0; i < n; i += k){
                ll w = 1;
                for(int j = 0; j < mid; ++j, w = w*wn%mod){
                    ll temp = (w*a[i+j+mid])%mod;
                    a[i+j+mid] = (a[i+j]-temp+mod)%mod;
                    a[i+j] = (a[i+j]+temp)%mod;
                }
            }
        }
        return;
    }
    void Calculate() {
        ntt(a, 1); ntt(b, 1);
        for (int i = 0; i < n; ++i)a[i] = a[i]*b[i]%mod;//记得取模
        ntt(a, -1);
        ll invl = qm(n, mod-2);
        for(int i = 0; i < n; ++i){
            a[i] = a[i]*invl%mod;
        }
    }
} F;
int n;
ll p[maxn], inv[maxn];
ll a[maxn], b[maxn];
ll p2[maxn];
void init()
{
    cin>>n;
    for(int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
    for(int i = 1; i <= n; ++i) scanf("%lld", &b[i]);
}
ll Com(int n, int m){
    ll ans = p[n]*inv[m]%mod*inv[n-m]%mod;
    return ans;
}
vector<ll> w[maxn<<2];
void sol(int rt, int l, int r){
    w[rt].clear();
    if(l == r){
        w[rt].push_back( (p2[b[l]]-1+mod)%mod );
        for(int i = 1; i <= a[l]; ++i){
            w[rt].push_back(Com(a[l], i));
        }
        return;
    }
    int mid = (l+r)>>1;
    sol(rt<<1, l, mid); sol(rt<<1|1, mid+1, r);
    int len = w[rt<<1].size() + w[rt<<1|1].size()-1;
    F.init(len);
    for(int i = 0; i < w[rt<<1].size(); ++i) F.a[i] = w[rt<<1][i];
    for(int i = 0; i < w[rt<<1|1].size(); ++i) F.b[i] = w[rt<<1|1][i];
    F.Calculate();
    for(int i = 0; i < len; ++i){
        w[rt].push_back(F.a[i]%mod);
    }
}
int main(){
    p2[0] = p[0] = inv[0] = 1;
    for(int i = 1; i < maxn; ++i) p[i] = p[i-1]*i%mod, inv[i] = qm(p[i], mod-2), p2[i] = p2[i-1]*2%mod;
    INIT();
    init();
    sol(1, 1, n);
    for(int i = 0; i < w[1].size(); ++i) {
        if(i > 0) printf(" ");
        printf("%lld", (w[1][i]+mod)%mod);
    }printf("\n");
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值