Description
$n \le 10^6,\ a_i \le 10^9$,输出
$$\sum_{i=0}^{n} f_i(n)\cdot(10^9+7)^i \bmod 2^{32}$$
Solution
- 完全没有想到可以直接将组合数用多项式表示,来推式子,经典姿势++
- 原式中 $a_i$ 的系数显然是 $[x^i]\,(1-x)^m(1+x)^{n-m}$
- 这个式子看起来不好求,我们把它转化一下
$$\begin{aligned}
&(1-x)^m(1+x)^{n-m}\\
=\;&(2-(1+x))^m(1+x)^{n-m}\\
=\;&\sum_{i=0}^{m}2^i C_m^i(-1)^{m-i}(1+x)^{n-i}
\end{aligned}$$
- 因为是
$a$ 点乘 $(1-x)^m(1+x)^{n-m}$,由上式展开,相当于是 $a$ 分别点乘各个 $(1+x)^{n-i}$ 后再加权求和,即:
$$a\cdot(1+x)^k=\sum_{i=0}^{k}C_k^i\,a_i$$
- 先卷积对所有 $k$ 算出点积 $a\cdot(1+x)^k$ 的结果,再代回上式卷积一次即可。
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
// Shared constants and type aliases. They are typed declarations rather than
// macros, so they obey scope rules and cannot silently rewrite other tokens.
const int maxn=3000005;        // capacity of every global array; must be >= the NTT length lim
typedef long long ll;
typedef unsigned int uint;     // unsigned arithmetic wraps modulo 2^32 (used for the final answer)
const int mo=998244353;        // NTT prime: 998244353 = 119 * 2^23 + 1; dft() uses 3 as its root
using namespace std;

int n;                         // input size; a_0..a_n are processed
int m,i,j,k;                   // global scratch counters (only i is used in the code shown here)
int lim;                       // NTT transform length, set in main
int bt[maxn];                  // bit-reversal permutation of 0..lim-1
ll A[maxn];                    // sequence a_0..a_n produced by the generator in main
ll S,M,B;                      // generator parameters read from the input
ll fct[maxn];                  // fct[i]  = i! mod mo
ll invf[maxn];                 // invf[i] = (i!)^{-1} mod mo
ll _2[maxn];                   // _2[i]   = 2^i mod mo
                               // NOTE(review): names starting with '_' are reserved at global
                               // scope; consider renaming together with its uses in main.
ll a[maxn],b[maxn],c[maxn];    // NTT work buffers
// Modular exponentiation by repeated squaring: returns x^y mod `mo`.
// y is expected to be non-negative (y == 0 yields 1).
// The base is first reduced into [0, mo). Without this, a base with |x| larger
// than about 3e9 makes the x*x product overflow long long (undefined behaviour),
// even when y == 1, because the squaring runs after every step.
// For 0 <= x < mo the result is exactly what the unreduced version returned.
ll ksm(ll x,ll y){
    x%=mo;
    if (x<0) x+=mo;
    ll s=1;
    for(;y;y/=2){
        if (y&1) s=s*x%mo;
        x=x*x%mo;   // both factors < mo, so the product stays below 2^63
    }
    return s;
}
// In-place iterative number-theoretic transform of a[0..lim) modulo `mo`.
// sig > 0 runs the forward transform. sig < 0 runs the inverse transform
// WITHOUT the 1/lim scaling; callers multiply by lim^{-1} themselves.
// Reads the globals `lim` (transform length) and `bt` (bit-reversal table).
// Outputs are only reduced with %, so entries lie in (-mo, mo) and may be negative.
void dft(ll *a,int sig){
    // Reorder into bit-reversed index order so butterflies can run bottom-up.
    for(int pos=0;pos<lim;pos++){
        int rev=bt[pos];
        if (pos<rev) swap(a[pos],a[rev]);
    }
    for(int half=1;half<lim;half*=2){
        int len=half*2;
        // Root of unity of order len (3 is used as the primitive root);
        // its inverse is used for the inverse transform.
        ll step=ksm(3,(mo-1)/len);
        if (sig<0) step=ksm(step,mo-2);
        for(int base=0;base<lim;base+=len){
            ll w=1;
            for(int off=0;off<half;off++){
                ll lo=a[base+off];
                ll hi=a[base+off+half]*w;
                a[base+off]=(lo+hi)%mo;
                a[base+off+half]=(lo-hi)%mo;
                w=w*step%mo;
            }
        }
    }
}
int main(){
freopen("count.in","r",stdin);
freopen("count.out","w",stdout);
scanf("%d%lld%lld%lld%lld",&n,&A[0],&S,&M,&B);
for(i=1;i<=n;i++) A[i]=((A[i-1]^S)*M+B)%mo;
fct[0]=1;for(i=1;i<=n;i++) fct[i]=fct[i-1]*i%mo;
_2[0]=1;for(i=1;i<=n;i++) _2[i]=_2[i-1]*2%mo;
invf[n]=ksm(fct[n],mo-2);
for(i=n-1;i>=0;i--) invf[i]=invf[i+1]*(i+1)%mo;
for(lim=1;lim<=2*n;lim<<=1);
for(i=1;i<lim;i++) bt[i]=(bt[i>>1]>>1)|((i&1)?(lim>>1):0);
for(i=0;i<=n;i++) a[i]=invf[i],b[i]=invf[i]*A[i]%mo;
dft(a,1),dft(b,1);
for(i=0;i<lim;i++) c[i]=a[i]*b[i]%mo;
dft(c,-1);
ll inv=ksm(lim,mo-2);
for(i=0;i<=n;i++) c[i]=c[i]*inv%mo*fct[i]%mo;
memset(a,0,sizeof(a)),memset(b,0,sizeof(b));
for(i=0;i<=n;i++) a[i]=_2[i]*invf[i]%mo*c[n-i]%mo,b[i]=invf[i]*((i&1)?-1:1);
dft(a,1),dft(b,1);
for(i=0;i<lim;i++) c[i]=a[i]*b[i]%mo;
dft(c,-1);
inv=ksm(lim,mo-2);
uint ans=0,mul=1,p=1e9+7;
for(i=0;i<=n;i++,mul=mul*p) {
c[i]=(c[i]*inv%mo*fct[i]%mo+mo)%mo;
ans+=mul*c[i];
}
printf("%u",ans);
}