原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ269.html
题目传送门 - UOJ269
题意
有一个多项式函数 $f(x)$,最高次幂为 $x^m$,定义变换 $Q$:
$$Q(f,n,x)=\sum_{k=0}^n f(k)\binom nk x^k(1−x)^{n−k}$$
现在给定函数 $f$ 和 $n,x$,求 $Q(f,n,x)\mod {\rm 998244353}$。
$f(x)$ 由 $0$~$m$ 的点值给出。
$1\leq n\leq 10^9,1\leq m \leq 2\times 10^4, 0\leq a_i,x <998244353$
题解
cly_none 太强了。
考虑一个 $m$ 次多项式 $f(x)$ ,必然可以拆成一堆下降幂的和。(忽略系数)其中,最高次项是 $m$ 次项,所以转成下降幂之后,最高次项就是一个 $m$ 阶下降幂。
对于 $f(x)$ 的某一个下降幂表示,设为 $x^\underline{k}$ ,那么,可以得到:
$$\begin{aligned} & \sum_{i=0}^n i^{\underline k} {n\choose i} x^i (1-x)^{n-i} \\ = & \sum_{i=k}^n i^{\underline k} \frac {n^{\underline k}}{i ^ {\underline k}} {n - k\choose i - k } x^i (1-x)^{n-i} \\ = & n^{\underline k} \sum_{i=k}^n {n - k\choose i - k } x^i (1-x)^{n-i} \\ = & n^{\underline k} x^k \sum_{i=0}^{n-k} {n - k\choose i} x^i (1-x)^{n-k-i} \\ = & n^{\underline k} x^k \end{aligned}$$
于是这样就证明了题目要求的式子是一个关于 $n$ 的 $m$ 次多项式。
于是只需要 FFT 一下,求出 $[0,m]$ 之间的整点的点值,然后插值来求答案。由于这些点值十分特殊,所以可以预处理阶乘来 $O(m)$ 求解。
总的时间复杂度为 $O(m\log m)$ 。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=1<<16,mod=998244353;
int read(){
int x=0;
char ch=getchar();
while (!isdigit(ch))
ch=getchar();
while (isdigit(ch))
x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return x;
}
int Pow(int x,int y){
int ans=1;
for (;y;y>>=1,x=1LL*x*x%mod)
if (y&1)
ans=1LL*ans*x%mod;
return ans;
}
int n,m,x,a[N],A[N],B[N];
int Fac[N],Inv[N];
int w[N],R[N];
int C(int n,int m){
if (m>n||m<0)
return 0;
return 1LL*Fac[n]*Inv[m]%mod*Inv[n-m]%mod;
}
void FFT(int a[],int n){
for (int i=0;i<n;i++)
if (R[i]<i)
swap(a[R[i]],a[i]);
for (int t=n>>1,d=1;d<n;d<<=1,t>>=1)
for (int i=0;i<n;i+=(d<<1))
for (int j=0;j<d;j++){
int tmp=1LL*w[t*j]*a[i+j+d]%mod;
a[i+j+d]=(a[i+j]+mod-tmp)%mod;
a[i+j]=(a[i+j]+tmp)%mod;
}
}
void Mul(int a[],int b[],int m){
int n,d;
for (n=1,d=0;n<=m*2+2;n<<=1,d++);
for (int i=0;i<n;i++)
R[i]=(R[i>>1]>>1)|((i&1)<<(d-1));
w[0]=1,w[1]=Pow(3,(mod-1)/n);
for (int i=2;i<n;i++)
w[i]=1LL*w[i-1]*w[1]%mod;
FFT(a,n);
FFT(b,n);
for (int i=0;i<n;i++)
a[i]=1LL*a[i]*b[i]%mod;
w[0]=1,w[1]=Pow(w[1],mod-2);
for (int i=2;i<n;i++)
w[i]=1LL*w[i-1]*w[1]%mod;
FFT(a,n);
int inv=Pow(n,mod-2);
for (int i=0;i<n;i++)
a[i]=1LL*a[i]*inv%mod;
}
int calc(int n){
int ans=0;
for (int k=0;k<=n;k++)
ans=(1LL*a[k]*C(n,k)%mod*Pow(x,k)%mod*Pow(mod+1-x,n-k)%mod+ans)%mod;
return ans;
}
int main(){
n=read(),m=read(),x=read();
for (int i=0;i<=m;i++)
a[i]=read();
for (int i=Fac[0]=Inv[0]=1;i<=m;i++){
Fac[i]=1LL*Fac[i-1]*i%mod;
Inv[i]=1LL*Inv[i-1]*Pow(i,mod-2)%mod;
}
if (n<=m)
return printf("%d\n",calc(n)),0;
for (int i=0;i<=m;i++){
A[i]=1LL*a[i]*Inv[i]%mod*Pow(x,i)%mod;
B[i]=1LL*Inv[i]*Pow(mod+1-x,i)%mod;
}
Mul(A,B,m);
for (int i=0;i<=m;i++)
A[i]=1LL*A[i]*Fac[i]%mod;
int ans=0;
for (int i=0;i<=m;i++){
int t=1LL*A[i]*Inv[i]%mod*Inv[m-i]%mod;
t=1LL*t*Pow(n+mod-i,mod-2)%mod;
if ((m-i)&1)
t=(mod-t)%mod;
ans=(ans+t)%mod;
}
for (int i=n;i>=n-m;i--)
ans=1LL*ans*i%mod;
printf("%d",ans);
return 0;
}