$\newcommand{\align}[1]{\begin{align*}#1\end{align*}}$题意:对于一个字符串$s$,定义$C(s)$为$s$中(出现次数最多的字母)出现的次数,问长度为$n$,字符集大小为$m$且$C(s)=k$的字符串有多少个
设$f_{i,j,k}$表示字符集大小为$i$,长度为$j$且$C(s)\leq k$的方案数,那么有$\align{f_{i,j,k}=\sum\limits_{l=0}^k\binom jlf_{i-1,j-l,k}}$(枚举最大字符的出现次数$l$,这个字符在$s$中出现的不同方案为$\align{\binom jl}$,剩下字符组成字符串的方案数为$f_{i-1,j-l,k}$)
这个DP式的第三维下标$k$没有变化,不妨删掉这维,并稍微推导一下:
$\align{f_{i,j}&=\sum\limits_{j=0}^k\binom jlf_{i-1,j-l}\\\dfrac{f_{i,j}}{j!}&=\sum\limits_{j=0}^k\dfrac1{l!}\dfrac{f_{i-1,j-l}}{(j-l)!}}$
这是卷积的形式,记$\align{F_i(x)=\sum\limits_{j=0}^k\dfrac{f_{i,j}x^j}{j!}}$,则$\align{F_i(x)=F_{i-1}(x)\left(\sum\limits_{j=0}^k\dfrac1{j!}\right)}$,直接快速幂就可以了,于是$\align{f_{i,j,k}=n!\left[x^n\right]\left(\sum\limits_{j=0}^k\dfrac1{j!}\right)^m}$,答案为$f_{i,j,k}-f_{i,j,k-1}$
#include<stdio.h>
#include<string.h>
const int mod=998244353;
typedef long long ll;
int mul(int a,int b){return a*(ll)b%mod;}
int ad(int a,int b){return(a+b)%mod;}
int de(int a,int b){return(a-b)%mod;}
int pow(int a,int b){
int s=1;
while(b){
if(b&1)s=mul(s,a);
a=mul(a,a);
b>>=1;
}
return s;
}
int rev[131072],N,iN;
void pre(int n){
int i,k;
for(N=1,k=0;N<n;N<<=1)k++;
for(i=0;i<N;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
iN=pow(N,mod-2);
}
void swap(int&a,int&b){a^=b^=a^=b;}
void ntt(int*a,int on){
int i,j,k,t,w,wn;
for(i=0;i<N;i++){
if(i<rev[i])swap(a[i],a[rev[i]]);
}
for(i=2;i<=N;i<<=1){
wn=pow(3,(on==1)?(mod-1)/i:(mod-1-(mod-1)/i));
for(j=0;j<N;j+=i){
w=1;
for(k=0;k<i>>1;k++){
t=mul(w,a[i/2+j+k]);
a[i/2+j+k]=de(a[j+k],t);
a[j+k]=ad(a[j+k],t);
w=mul(w,wn);
}
}
}
if(on==-1){
for(i=0;i<N;i++)a[i]=mul(a[i],iN);
}
}
void pow(int*a,int n,int k,int*s){
int i;
s[0]=1;
pre((n+1)<<1|1);
while(k){
ntt(a,1);
if(k&1){
ntt(s,1);
for(i=0;i<N;i++)s[i]=mul(s[i],a[i]);
ntt(s,-1);
for(i=n+1;i<N;i++)s[i]=0;
}
for(i=0;i<N;i++)a[i]=mul(a[i],a[i]);
ntt(a,-1);
for(i=n+1;i<N;i++)a[i]=0;
k>>=1;
}
}
int fac[50010],rfac[50010],a[131072],b[131072];
int solve(int n,int m,int k){
int i;
memset(a,0,sizeof(a));
memset(b,0,sizeof(b));
for(i=0;i<=k;i++)a[i]=rfac[i];
pow(a,n,m,b);
return mul(b[n],fac[n]);
}
int main(){
int n,m,k,i;
scanf("%d%d%d",&n,&m,&k);
fac[0]=1;
for(i=1;i<=n;i++)fac[i]=mul(fac[i-1],i);
rfac[n]=pow(fac[n],mod-2);
for(i=n;i>0;i--)rfac[i-1]=mul(rfac[i],i);
printf("%d",(de(solve(n,m,k),solve(n,m,k-1))+mod)%mod);
}