题目链接:
题目大意:给出一个多项式$m+1$个点值$a_{0},a_{1}...a_{m}$(其中$f(i)=a_{i}$),并给出两个数$n,x$,求$Q(f,n,x)=\sum\limits_{k=0}^{n}f(k)C_{n}^{k}x^k(1-x)^{n-k}mod998244353$的值。
当$f(x)=1$时,$Q=\sum\limits_{i=0}^{n}C_{n}^{i}k^i(1-k)^{n-i}$,根据二项式定理可知这个式子结果为$1$。
当$f(x)=x$时,$Q=\sum\limits_{i=0}^{n}i\frac{n!}{i!(n-i)!}k^i(1-k)^{n-i}$
$Q=\sum\limits_{i=0}^{n}nk\frac{(n-1)!}{(i-1)!(n-i)!}k^{i-1}(1-k)^{n-i}$
$Q=nk\sum\limits_{i=0}^{n}\frac{(n-1)!}{(i-1)!(n-i)!}k^{i-1}(1-k)^{n-i}$
$Q=nk\sum\limits_{i=0}^{n}C_{n-1}^{i-1}k^{i-1}(1-k)^{n-i}$
根据二项式定理可知,$Q=nk$。
进一步可以发现当$f(x)=x^{\underline{d}}$时,$Q=n^{\underline{d}}k^{d}$。其中$x^{\underline{d}}$表示$x$的$d$次下降幂即$x(x-1)(x-2)...(x-d+1)$,也就是$\frac{x!}{(x-d)!}$。推导过程和上面类似。
$Q=\sum\limits_{i=0}^{n}i^{\underline{d}}\frac{n!}{i!(n-i)!}k^i(1-k)^{n-i}$
$Q=\sum\limits_{i=0}^{n}\frac{i!}{(i-d)!}\frac{n!}{i!(n-i)!}k^{i}(1-k)^{n-i}$
$Q=\sum\limits_{i=0}^{n}n^{\underline{d}}k^{d}\frac{(n-d)!}{(i-d)!(n-i)!}k^{i-d}(1-k)^{n-i}$
$Q=n^{\underline{d}}k^{d}\sum\limits_{i=0}^{n}C_{n-d}^{i-d}k^{i-d}(1-k)^{n-i}$
根据二项式定理,后面那个还等于$1$,所以$Q=n^{\underline{d}}k^{d}$。
因为$x^{\underline{i}}$的最高次幂是$x^i$,所以一个$m$次多项式可以写成$f(x)=\sum\limits_{i=0}^{m}a_{i}x^{\underline{i}}$。那么
$Q(f,n,x)=\sum\limits_{i=0}^{m}a_{i}*Q(x^{\underline{i}},n,k)=\sum\limits_{i=0}^{m}a_{i}*n^{\underline{i}}k^i$
现在考虑如何求$a_{i}$,设$a_{i}=\frac{b_{i}}{i!}$,那么$f(x)=\sum\limits_{i=0}^{m}b_{i}\frac{x^{\underline{i}}}{i!}=\sum\limits_{i=0}^{m}b_{i}C_{x}^{i}$
因为我们知道当$x=0,1,2...m$时$f(x)$的值,所以
当$x=0$时,$f(x)=b_{0}$
我们设$\Delta f(x)=f(x+1)-f(x)$(即一阶差分)。
因为$C_{x+1}^{i}-C_{x}^{i}=C_{x}^{i-1}$,所以$\Delta f(x)=\sum\limits_{i=0}^{m}b_{i}C_{x}^{i-1}$。
那么$\Delta f(0)=b_{1}$,由此可以推出$\Delta^{k}f(0)=b_{k}$(即$k$阶差分)。
至此可以得到一个$O(m^2)$的暴力差分做法(实际上是能$AC$的)。
但我们展开$k$阶差分的第一项(即$\Delta^{k}f(0)$)表达式可以发现:
$b_{k}=\sum\limits_{i=0}^{k}(-1)^{k-i}C_{k}^{i}f(i)$
$b_{k}=\sum\limits_{i=0}^{k}(-1)^{k-i}\frac{k!}{i!(k-i)!}f(i)$
$\frac{b_{k}}{k!}=a_{k}=\sum\limits_{i=0}^{k}\frac{(-1)^{k-i}}{(k-i)!}*\frac{f(i)}{i!}$
我们设$F(i)=\frac{f(i)}{i!},G(i)=\frac{(-1)^i}{i!},A(i)=a_{i}$
那么$A(i)=F(i)*G(i)$用$FFT$或$NTT$多项式乘法一下即可将时间复杂度降到$O(mlog_{m})$。
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<vector>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
#define mod 998244353
using namespace std;
int n,k,m;
ll ans;
ll res;
int a[80010];
int b[80010];
int inv[80010];
int len;
ll quick(ll x,int y)
{
ll res=1ll;
while(y)
{
if(y&1)
{
res=res*x%mod;
}
y>>=1;
x=x*x%mod;
}
return res;
}
void NTT(int *a,int len,int miku)
{
for(int k=0,i=0;i<len;i++)
{
if(i>k)
{
swap(a[i],a[k]);
}
for(int j=len>>1;(k^=j)<j;j>>=1);
}
for(int k=2;k<=len;k<<=1)
{
int t=k>>1;
int x=quick(3,(mod-1)/k);
if(miku==-1)
{
x=quick(x,mod-2);
}
for(int i=0;i<len;i+=k)
{
int w=1;
for(int j=i;j<i+t;j++)
{
int tmp=1ll*a[j+t]*w%mod;
a[j+t]=(a[j]-tmp+mod)%mod;
a[j]=(a[j]+tmp)%mod;
w=1ll*w*x%mod;
}
}
}
if(miku==-1)
{
for(int i=0,t=quick(len,mod-2);i<len;i++)
{
a[i]=1ll*a[i]*t%mod;
}
}
}
int main()
{
scanf("%d%d%d",&n,&m,&k);
for(int i=0;i<=m;i++)
{
scanf("%d",&a[i]);
}
len=1;
while(len<=(m<<1))
{
len<<=1;
}
inv[0]=inv[1]=1;
for(int i=2;i<=m;i++)
{
inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
}
for(int i=1;i<=m;i++)
{
inv[i]=1ll*inv[i]*inv[i-1]%mod;
}
for(int i=0;i<=m;i++)
{
a[i]=1ll*a[i]*inv[i]%mod;
b[i]=(i&1)?mod-inv[i]:inv[i];
}
NTT(a,len,1);
NTT(b,len,1);
for(int i=0;i<len;i++)
{
a[i]=1ll*a[i]*b[i]%mod;
}
NTT(a,len,-1);
res=1ll;
for(int i=0;i<=m;i++)
{
ans=(ans+res*a[i])%mod;
res=res*k%mod*(n-i)%mod;
}
printf("%lld",ans);
}