问题可以转化为:w列,每一列可以放任意高度的柱子。如果当前的柱子比前一列的柱子高度多d,就等于原问题的积木多了d个。
设dp0[i][j]表示从左往右考虑,第i列的高度为j的方案数,dp1[i][j]表示积木数,dp2[i][j]表示每种方案的平方和。
显然有转移\(dp0[i][j]=\sum_{k=1}^hdp0[i][k]\),你也可以写成\(dp0[i][j]=h^i\),不过这都不重要了。
然后有\(dp1[i][j]=\sum_{k=1}^hdp1[i][k]+\sum_{k=1}^jdp0[i][j]*(j-k)\)
以及\(dp2[i][j]=\sum_{k=1}^hdp2[i-1][k]+2*\sum_{k=1}^jdp1[i-1][k]*(j-k)+\sum_{k=1}^jdp0[i-1][k]*(j-k)^2\)
说说这个转移怎么来的。你考虑枚举所有\(i-1\)列的方案,就有\(dp1[i][j]=\sum_{方案}(方案的积木数+j-k)=\sum_{方案}方案的积木数+方案数*d=\sum dp1[i][k]+\sum dp0[i][k]*(j-k)\)
然后\(dp2[i][j]=\sum_{方案}(方案的积木数+j-k)^2=\sum_{方案}(方案的积木数^2+2*方案的积木数*(j-k)+(j-k)^2)=\sum dp2[i][k]+\sum dp1[i][k]*(j-k)+\sum dp0[i][k]*(j-k)^2\)。
最后前缀和优化一下即可。
#include<bits/stdc++.h>
#define rg register
#define il inline
#define cn const
#define fp(i,a,b) for(rg int i=(a),ed=(b);i<=ed;++i)
#define fb(i,a,b) for(rg int i=(a),ed=(b);i>=ed;--i)
using namespace std;
typedef cn int cint;
typedef long long LL;
cint maxn=5010,mod=998244353;
int n,m,f[2][maxn][3],sum[2][maxn][3],sum2[2][maxn][2],sum3[2][maxn];
il void add(int &a,LL b){a+=b;while(a>=mod)a-=mod;}
il int slv(cint &n,cint &m){
memset(f,0,sizeof f);
fp(i,1,m)f[1][i][0]=1,f[1][i][1]=i,f[1][i][2]=i*i;
fp(i,1,m){
sum[1][i][0]=(sum[1][i-1][0]+f[1][i][0])%mod;
sum[1][i][1]=(sum[1][i-1][1]+f[1][i][1])%mod;
sum[1][i][2]=(sum[1][i-1][2]+f[1][i][2])%mod;
sum2[1][i][0]=(sum2[1][i-1][0]+1ll*f[1][i][0]*i)%mod;
sum2[1][i][1]=(sum2[1][i-1][1]+1ll*f[1][i][1]*i)%mod;
sum3[1][i]=(sum3[1][i-1]+1ll*f[1][i][0]*i%mod*i)%mod;
}
rg int now=1,pre=0;
fp(i,2,n){
swap(now,pre);
memset(f[now],0,sizeof f[now]);
memset(sum[now],0,sizeof sum[now]);
memset(sum2[now],0,sizeof sum2[now]);
memset(sum3[now],0,sizeof sum3[now]);
fp(j,1,m){
f[now][j][0]=sum[pre][m][0];
f[now][j][1]=(sum[pre][m][1]+1ll*sum[pre][j-1][0]*j-sum2[pre][j-1][0]+mod)%mod;
f[now][j][2]=(sum[pre][m][2]+2ll*sum[pre][j-1][1]*j-2ll*sum2[pre][j-1][1]%mod+sum3[pre][j-1]+1ll*sum[pre][j-1][0]*j%mod*j-2ll*j*sum2[pre][j-1][0]%mod+mod+mod)%mod;
}
fp(j,1,m){
sum[now][j][0]=(sum[now][j-1][0]+f[now][j][0])%mod;
sum2[now][j][0]=(sum2[now][j-1][0]+1ll*f[now][j][0]*j)%mod;
sum3[now][j]=(sum3[now][j-1]+1ll*f[now][j][0]*j%mod*j)%mod;
sum[now][j][1]=(sum[now][j-1][1]+f[now][j][1])%mod;
sum2[now][j][1]=(sum2[now][j-1][1]+1ll*f[now][j][1]*j)%mod;
sum[now][j][2]=(sum[now][j-1][2]+f[now][j][2])%mod;
}
}
return sum[now][m][2];
}
int main(){
scanf("%d %d",&n,&m);
printf("%d\n",(slv(n,m)-slv(n,m-1)+mod)%mod);
return 0;
}