一些废话:
写完上一题(分治FFT)之后记起之前牛客写的这道题,题解是用启发式NTT,每次选两个项数最小的合并,当时没有整理NTT的模板所以就先mark了,今天受到分治FFT的启发,想用分治NTT来解决这道题。
复杂度分析:一共会分出log(n)层,每层的项数总和数量级是
Σ
a
i
\Sigma a_i
Σai,ntt复杂度nlogn,所以总复杂度是
O
(
n
l
o
g
2
n
)
O(nlog^2n)
O(nlog2n).
这题也是用生成函数,每一场的生成函数为:
(
2
b
i
−
1
)
+
∑
i
=
1
a
i
C
a
i
i
x
i
(2^{b_i}-1)+\sum_{i=1}^{a_i}C_{a_i}^ix^i
(2bi−1)+∑i=1aiCaiixi
把每一场的函数乘起来之后
x
i
x^i
xi的系数就是恰好取i个的答案。
(一定要好好审题
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 2e5 + 50;
const ll mod = 998244353;
ll qm(ll a, ll b){ll res = 1; while(b){if(b&1) res = res*a%mod; a = a*a%mod; b >>= 1; } return res; }
ll mulwn[maxn<<2], invwn[maxn<<2];
void INIT(){
for(ll i=1;i < maxn*4;i<<=1) mulwn[i]=qm(3,(mod-1)/i);
for(ll i=1;i < maxn*4;i<<=1) invwn[i]=qm(mulwn[i],mod-2);
}
struct FFT {
ll n, m, rev[maxn << 2];
ll a[maxn << 2], b[maxn << 2];
void init(int len) {
for (n = 1, m = 0; n < len + len; n <<= 1, m++);
for (int i = 0; i < n; ++i) {
rev[i] = (rev[i >> 1] >> 1) | (i & 1) << (m - 1);
a[i] = 0;
b[i] = 0;
}
}
void ntt(ll *a, int f) {//
for (int i = 0; i < n; ++i)if (i < rev[i])swap(a[i], a[rev[i]]);
for (int k = 2; k <= n; k <<= 1) {
ll wn=(f>0)?mulwn[k]:invwn[k];
int mid = k>>1;
for(int i = 0; i < n; i += k){
ll w = 1;
for(int j = 0; j < mid; ++j, w = w*wn%mod){
ll temp = (w*a[i+j+mid])%mod;
a[i+j+mid] = (a[i+j]-temp+mod)%mod;
a[i+j] = (a[i+j]+temp)%mod;
}
}
}
return;
}
void Calculate() {
ntt(a, 1); ntt(b, 1);
for (int i = 0; i < n; ++i)a[i] = a[i]*b[i]%mod;//记得取模
ntt(a, -1);
ll invl = qm(n, mod-2);
for(int i = 0; i < n; ++i){
a[i] = a[i]*invl%mod;
}
}
} F;
int n;
ll p[maxn], inv[maxn];
ll a[maxn], b[maxn];
ll p2[maxn];
void init()
{
cin>>n;
for(int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
for(int i = 1; i <= n; ++i) scanf("%lld", &b[i]);
}
ll Com(int n, int m){
ll ans = p[n]*inv[m]%mod*inv[n-m]%mod;
return ans;
}
vector<ll> w[maxn<<2];
void sol(int rt, int l, int r){
w[rt].clear();
if(l == r){
w[rt].push_back( (p2[b[l]]-1+mod)%mod );
for(int i = 1; i <= a[l]; ++i){
w[rt].push_back(Com(a[l], i));
}
return;
}
int mid = (l+r)>>1;
sol(rt<<1, l, mid); sol(rt<<1|1, mid+1, r);
int len = w[rt<<1].size() + w[rt<<1|1].size()-1;
F.init(len);
for(int i = 0; i < w[rt<<1].size(); ++i) F.a[i] = w[rt<<1][i];
for(int i = 0; i < w[rt<<1|1].size(); ++i) F.b[i] = w[rt<<1|1][i];
F.Calculate();
for(int i = 0; i < len; ++i){
w[rt].push_back(F.a[i]%mod);
}
}
int main(){
p2[0] = p[0] = inv[0] = 1;
for(int i = 1; i < maxn; ++i) p[i] = p[i-1]*i%mod, inv[i] = qm(p[i], mod-2), p2[i] = p2[i-1]*2%mod;
INIT();
init();
sol(1, 1, n);
for(int i = 0; i < w[1].size(); ++i) {
if(i > 0) printf(" ");
printf("%lld", (w[1][i]+mod)%mod);
}printf("\n");
}