// Write the code first. (original note: 先写个代码)
#include <bits/stdc++.h>
using namespace std;
using ll = long long;           // 64-bit integer shorthand used throughout the file
constexpr int mod = 998244353;  // prime modulus for all arithmetic (init() relies on primality via Fermat)
constexpr int N = 200000;       // largest index of the factorial tables
ll fact[N + 5], inv[N + 5];     // filled by init(): fact[i] = i! mod mod, inv[i] = (i!)^{-1} mod mod
ll n, m, k;                     // input values read by main()
ll ans1, ans2;                  // inclusion-exclusion sums accumulated by main()
// Modular exponentiation by repeated squaring: returns a^b mod `mod`.
// The base is not reduced first, so callers are expected to pass 0 <= a < mod
// (true for the factorial values used in this file) and b >= 0.
ll fpow(ll a,ll b)
{
    ll acc = 1;
    ll base = a;
    for(ll e = b;e != 0;e >>= 1)
    {
        // Multiply in the current power of the base when this exponent bit is set.
        if(e & 1)acc = acc * base % mod;
        base = base * base % mod;
    }
    return acc;
}
// Fills the global tables for indices 0..N:
//   fact[i] = i! mod mod
//   inv[i]  = (i!)^{-1} mod mod
// Only (N!)^{-1} is computed with fpow (Fermat's little theorem, mod is prime).
// Every smaller inverse factorial follows from (i-1)!^{-1} = i!^{-1} * i.
void init()
{
    fact[0] = 1;
    for(int i = 0;i < N;i++)fact[i + 1] = fact[i] * (i + 1) % mod;

    inv[N] = fpow(fact[N],mod - 2);
    for(int i = N;i > 1;i--)inv[i - 1] = inv[i] * i % mod;
    inv[0] = 1;
}
// Binomial coefficient C(a, b) modulo `mod`, read from the tables filled by init().
// Returns 0 when the coefficient is zero by definition (b < 0, b > a) or a < 0.
// Without this guard those arguments index fact/inv with negative offsets,
// which is undefined behavior.
// NOTE(review): a must not exceed N (table size); callers are responsible for that.
ll C(ll a,ll b)
{
    if(a < 0 || b < 0 || b > a)return 0;
    return fact[a] * inv[b] % mod * inv[a - b] % mod;
}
int main(int argc,char *argv[])
{
init();
cin >> n >> m >> k;
if(k == 0)
{
if(m == 0)
{
printf("1");
}
else printf("0");
return 0;
}
for(int i = 1;1ll * i * k <= m;i++)
{
if(i & 1)
{
ans1 += C(n - m + 1,i) * C(n - i * k,n - m) % mod;
ans1 %= mod;
}
else
{
ans1 -= C(n - m + 1,i) * C(n - i * k,n - m) % mod;
ans1 += mod;
ans1 %= mod;
}
}
k++;
for(int i = 1;1ll * i * k <= m;i++)
{
if(i & 1)
{
ans2 += C(n - m + 1,i) * C(n - i * k,n - m) % mod;
ans2 %= mod;
}
else
{
ans2 -= C(n - m + 1,i) * C(n - i * k,n - m) % mod;
ans2 += mod;
ans2 %= mod;
}
}
cout << (ans1 - ans2 + mod) % mod << endl;
return 0;
}