我会ntt?假的。我会多项式求逆?假的。我会斯特林数?假的。
官方给的题解是cdq+ntt或者多项式求逆,然而我不会gg
还好还有一种可以直接ntt的,给跪orz
Ans=∑i=0n∑j=0iSi,j×2j×(j!)
Si,j=1j!∑k=0j(−1)kCkj(j−k)i
带入得
Ans=∑i=0n∑j=0i∑k=0j(−1)kj!k!(j−k)!(j−k)i×2j
我们改变一下求和顺序,枚举j,得到
Ans=∑j=0n2j×j!∑k=0j(−1)kk!(j−k)!∑i=jn(j−k)i
令
ai=(−1)ii!,bi=∑k=0niki!
则
Ans=∑j=0n2j×j!∑k=0jak∗bj−k
是个卷积,可以直接ntt解决了。至于各种上下界范围都可以取到0~n,因为越界的都得0了。
预处理逆元,阶乘,并利用等比公式来算。
时间复杂度 O(nlogn)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
#define ll long long
#define inf 0x3f3f3f3f
#define N 100010
#define mod 998244353
#define G 3
inline char gc(){
static char buf[1<<16],*S,*T;
if(S==T){T=(S=buf)+fread(buf,1,1<<16,stdin);if(T==S) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=gc();
return x*f;
}
int n,m,R[N<<2],L=0,inv[N],fac[N],ifac[N],a[N<<2],b[N<<2],ni;
inline int ksm(int x,int k){
int res=1;for(;k;k>>=1,x=(ll)x*x%mod) if(k&1) res=(ll)res*x%mod;return res;
}
inline void ntt(int *a,int f){
for(int i=0;i<n;++i) if(i>R[i]) swap(a[i],a[R[i]]);
for(int i=1;i<n;i+=i){
int wn=ksm(G,f==1?(mod-1)/(2*i):mod-1-(mod-1)/(i*2));
for(int j=0,p=i<<1;j<n;j+=p){
int w=1;
for(int k=0;k<i;++k,w=(ll)w*wn%mod){
int x=a[j+k],y=(ll)a[j+k+i]*w%mod;
a[j+k]=(x+y)%mod;a[j+k+i]=(x-y)%mod;
}
}
}if(f==-1) for(int i=0;i<n;++i) a[i]=(ll)a[i]*ni%mod;
}
int main(){
// freopen("sum.in","r",stdin);
n=read();inv[1]=1;fac[0]=ifac[0]=1;
for(int i=2;i<=n;++i) inv[i]=(ll)inv[mod%i]*(mod-mod/i)%mod;
for(int i=1;i<=n;++i) fac[i]=(ll)fac[i-1]*i%mod;
for(int i=1;i<=n;++i) ifac[i]=(ll)ifac[i-1]*inv[i]%mod;
for(int i=0;i<=n;++i) a[i]=(i&1?-ifac[i]:ifac[i]);b[0]=1;b[1]=n+1;
for(int i=2;i<=n;++i) b[i]=((ll)ksm(i,n+1)-1)*inv[i-1]%mod*ifac[i]%mod;
m=n<<1;for(n=1;n<=m;n+=n) L++;ni=ksm(n,mod-2);
for(int i=0;i<n;++i) R[i]=(R[i>>1]>>1)|(i&1)<<L-1;
ntt(a,1);ntt(b,1);
for(int i=0;i<n;++i) a[i]=(ll)a[i]*b[i]%mod;
ntt(a,-1);int bin=1,ans=0;n=m>>1;
for(int i=0;i<=n;++i){
(ans+=(ll)fac[i]*bin%mod*a[i]%mod)%=mod;bin=bin*2%mod;
}if(ans<0) ans+=mod;
printf("%d\n",ans);
return 0;
}