其实就是要计算
∑ni=0(ni)ik ∑ i = 0 n ( i n ) i k
拆一下
∑i(ni)∑kj=0{kj}j!(ij) ∑ i ( i n ) ∑ j = 0 k { j k } j ! ( j i )
画一下柿子
∑kj=0{kj}j!∑i(ni)(ij) ∑ j = 0 k { j k } j ! ∑ i ( i n ) ( j i )
∑j{kj}j!(nj)∑ni=0(n−ji−j) ∑ j { j k } j ! ( j n ) ∑ i = 0 n ( i − j n − j )
∑j{kj}j!(nj)2n−j ∑ j { j k } j ! ( j n ) 2 n − j
用 {kj}=1j!∑ji=0(−1)j−i(ji)ik { j k } = 1 j ! ∑ i = 0 j ( − 1 ) j − i ( i j ) i k NTT预处理第二类斯特林数就可以做了
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn = 610000;
const int mod = 998244353;
inline void add(int &a,const int &b){a+=b;if(a>=mod)a-=mod;}
int pw(int x,ll k)
{
int re=1;
for(;k;k>>=1,x=(ll)x*x%mod) if(k&1ll)
re=(ll)re*x%mod;
return re;
}
int inv(int x){ return pw(x,mod-2); }
int s[maxn],invs[maxn];
void pre()
{
s[0]=1; for(int i=1;i<maxn;i++) s[i]=(ll)s[i-1]*i%mod;
invs[maxn-1]=inv(s[maxn-1]);
for(int i=maxn-2;i>=0;i--) invs[i]=(ll)invs[i+1]*(i+1)%mod;
}
const int g = 3;
int N,ln,A[maxn],B[maxn],id[maxn],w[maxn];
void FNT(int f[],int sig)
{
for(int i=1;i<N;i++) if(i<id[i]) swap(f[i],f[id[i]]);
for(int m=2;m<=N;m<<=1)
{
int t=m>>1,tt=N/m;
for(int i=0;i<t;i++)
{
int wn=sig==1?w[i*tt]:w[N-i*tt];
for(int j=i;j<N;j+=m)
{
int tx=f[j],ty=(ll)f[j+t]*wn%mod;
f[j]=(tx+ty)%mod;
f[j+t]=(tx-ty+mod)%mod;
}
}
}
if(sig==-1)
{
int invn=inv(N);
for(int i=0;i<N;i++) f[i]=(ll)f[i]*invn%mod;
}
}
void cal(int n)
{
N=1,ln=0; while(N<=(n<<1)) N<<=1,ln++;
for(int i=1;i<N;i++) id[i]=id[i>>1]>>1|((i&1)<<ln-1);
w[0]=1; w[1]=pw(g,(mod-1)/N);
for(int i=2;i<=N;i++) w[i]=(ll)w[i-1]*w[1]%mod;
for(int i=0;i<=n;i++) A[i]=(ll)pw(i,n)*invs[i]%mod;
for(int i=0;i<=n;i++) B[i]=(i&1)?mod-invs[i]:invs[i];
FNT(A,1); FNT(B,1);
for(int i=0;i<N;i++) A[i]=(ll)A[i]*B[i]%mod;
FNT(A,-1);
}
int n,K,cnt,inv2=(mod+1)>>1;
int main()
{
//freopen("tmp.in","r",stdin);
//freopen("tmp.out","w",stdout);
pre();
scanf("%d%d",&n,&K); cnt=pw(2,(ll)n*(n-1)/2ll-n+1);
n--;
cal(K);
int ans=0;
for(int j=0,j2=pw(2,n),cj=1;j<=K;j++)
{
add(ans,(ll)A[j]*s[j]%mod*cj%mod*j2%mod);
j2=(ll)j2*inv2%mod;
cj=(ll)cj*(n-j)%mod*inv(j+1)%mod;
}
ans=(ll)ans*cnt%mod*(n+1)%mod;
printf("%d\n",ans);
return 0;
}