目录
题意很简单
n场比赛,赢了m场,最大连赢k场有多少种排列
容斥
一般这种恰好k场的题目会考虑先算出大于等于k场的排列减去大于等于k+1场的排列
首先我们定义为 n场赢m场,最大连赢大于等于k的排列
答案为
那现在考虑怎么计算ansk
首先我们可以枚举i为 1个k连胜,2个k连胜.......
然后进行容斥,奇加偶减
对于第个连胜
我们首先要在个输的场的空隙中选i个来放我们i个k连胜
然后接下来剩下的球任意放在的空隙里,可以有空
剩下球的数目为
等同于把n个球放进m个盒子里,可以有空盒子
不懂可以去看->(50条消息) [概率练习]n个小球放入m个盒子_Nightmare004的博客-CSDN博客_n个球放入m个盒子定理
然后把答案容斥一下可以得到
输出即可
代码如下:
#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define fer(i,a,b) for(int i=a;i<=b;++i)
#define der(i,a,b) for(int i=a;i>=b;--i)
#define all(x) (x).begin(),(x).end()
#define pll pair<int,int>
#define et cout<<'\n'
#define xx first
#define yy second
#define double long double
using namespace std;
const int mod=998244353;
const int N=1e6+10;
int qpow(int a,int b)
{
int t=1;
while(b)
{
if(b&1)t=t*a%mod;
a=a*a%mod;
b>>=1;
}
return t;
}
inline int mul(int x, int y) {
return (x * y) % mod;
}
int fc[N],gc[N];
inline void init(){
fc[0]=1;
for(int i=1;i<=500001;i++) fc[i]=mul(fc[i-1],i);
gc[500001]=qpow(fc[500001],mod-2);
for(int i=500001;i>=1;i--) gc[i-1]=mul(gc[i],i);
}
inline int C(int i,int j){
if(j>i)return 0;
return mul(mul(fc[i],gc[j]),gc[i-j]);
}//大组合数
int n,m,k;
int cal(int k){
int ans=0;
for(int i=1;i*k<=m;i++)
{
if(i&1)ans=(ans+C(n-m+1,i)*C(n-i*k,n-m)%mod)%mod;
else ans=(ans-C(n-m+1,i)*C(n-i*k,n-m)%mod)%mod;
}
return ans;
}
signed main()
{
init();
cin>>n>>m>>k;
int ans=0;
if(k==0)
{
cout<<(m==0)<<endl;
return 0;
}
ans=((cal(k)-cal(k+1))%mod+mod)%mod;
cout<<ans<<endl;
return 0;
}
生成函数/多项式快速幂
生成函数的解法更暴力好像一些
对于最多连胜小于等于k场
个空位,每个空位放1胜,2胜,......
这样每个函数生成函数为
而一共有个空位
所以对于答案为
的 m次项的系数
最后输出即可
代码如下
#include <bits/stdc++.h>
#define int long long
#define x first
#define y second
#define poly vector<int>
#define len(x) ((int)x.size())
using namespace std;
const int N = 3e5 + 5, G = 3, Ginv = 332748118, mod = 998244353;
int n, m, rev[N], lim, w, k;
poly f, g;
struct nd {
int x, y;
nd (int a = 0, int b = 0) {
x = a;
y = b;
}
nd operator * (const nd &a) const {
return nd((x * a.x % mod + y * a.y % mod * w % mod) % mod, (x * a.y % mod + y * a.x % mod) % mod);
}
};
int readmod() {
int x = 0;
char s = getchar();
while (s < '0' || s > '9') s = getchar();
while (s >= '0' && s <= '9') {
x = ((x << 1) + (x << 3) + (s - '0')) % mod;
s = getchar();
}
return x;
}
int qmi(int a, int b) {
int res = 1;
while (b) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
nd qmi(nd a, int b, nd res = nd(1, 0)) {
while (b) {
if (b & 1) res = res * a;
a = a * a;
b >>= 1;
}
return res;
}
bool check_sqrt(int a) {
return qmi(a, (mod - 1) / 2) == 1;
}
int get_w(int n) {
for (int a = rand(); ; a = rand()) {
if (!check_sqrt((a * a % mod - n + mod) % mod)) {
return w = (a * a % mod - n + mod) % mod, a;
}
}
}
int get_sqrt(int n) {
int a = get_w(n), x_1 = (qmi(nd(a, 1), (mod + 1) / 2)).x, x_2 = mod - x_1;
return x_1 < x_2 ? x_1 : x_2;
}
void polyinit(int n) {
for (lim = 1; lim < n; lim <<= 1);
for (int i = 0; i < lim; i ++) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? lim >> 1 : 0);
}
void NTT(poly &f, int op) {
for (int i = 0; i < lim; i ++) {
if (i < rev[i]) swap(f[i], f[rev[i]]);
}
for (int mid = 1; mid < lim; mid <<= 1) {
int Gn = qmi(op == 1 ? G : Ginv, (mod - 1) / (mid << 1));
for (int i = 0; i < lim; i += mid * 2) {
for (int j = 0, G0 = 1; j < mid; j ++, G0 = G0 * Gn % mod) {
int x = f[i + j], y = G0 * f[i + j + mid] % mod;
f[i + j] = (x + y) % mod, f[i + j + mid] = (x - y + mod) % mod;
}
}
}
if (op == -1) {
int inv = qmi(lim, mod - 2);
for (int i = 0; i < lim; i ++) f[i] = f[i] * inv % mod;
}
}
poly operator + (poly f, int x) {
f[0] = (f[0] + x) % mod;
return f;
}
poly operator - (poly f, int x) {
f[0] = (f[0] - x + mod) % mod;
return f;
}
poly operator * (poly f, int x) {
for (int i = 0; i < len(f); i ++) f[i] = f[i] * x % mod;
return f;
}
poly operator + (poly f, poly g) {
poly res = f;
res.resize(max(len(f), len(g)));
for (int i = 0; i < len(g); i ++) res[i] = (res[i] + g[i]) % mod;
return res;
}
poly operator - (poly f, poly g) {
poly res = f;
res.resize(max(len(f), len(g)));
for (int i = 0; i < len(g); i ++) res[i] = (res[i] - g[i] + mod) % mod;
return res;
}
poly operator * (poly f, poly g) {
int n = len(f) + len(g) - 1;
polyinit(n), f.resize(lim), g.resize(lim);
NTT(f, 1), NTT(g, 1);
for (int i = 0; i < lim; i ++) f[i] = f[i] * g[i] % mod;
NTT(f, -1), f.resize(n);
return f;
}
poly polyinv(poly f, int n) {
poly res(1, qmi(f[0], mod - 2)), t;
for (int len = 2; len < n << 1; len <<= 1) {
t.resize(len), polyinit(len << 1);
for (int i = 0; i < len; i ++) t[i] = i < len(f) ? f[i] : 0;
t.resize(lim), res.resize(lim);
NTT(t, 1), NTT(res, 1);
for (int i = 0; i < lim; i ++) res[i] = res[i] * (2 - res[i] * t[i] % mod + mod) % mod;
NTT(res, -1), res.resize(len);
}
res.resize(n);
return res;
}
pair<poly, poly> polydiv(poly f, poly g) {
int n = len(f) - len(g) + 1;
poly q = f;
reverse(q.begin(), q.end()), q.resize(n);
reverse(g.begin(), g.end());
q = q * polyinv(g, n), q.resize(n);
reverse(q.begin(), q.end()), reverse(g.begin(), g.end());
poly r = f - g * q;
return {q, r};
}
poly polysqrt(poly f, int n) {
int inv2 = qmi(2, mod - 2);
poly b(1, get_sqrt(f[0])), c, d;
for (int len = 4; (len >> 2) < n; len <<= 1) {
c = f, c.resize(len >> 1), polyinit(len);
d = polyinv(b, len >> 1);
c.resize(lim), d.resize(lim);
NTT(c, 1), NTT(d, 1);
for (int i = 0; i < lim; i ++) c[i] = c[i] * d[i] % mod;
NTT(c, -1);
b.resize(len >> 1);
for (int i = 0; i < (len >> 1); i ++) b[i] = (c[i] + b[i]) % mod * inv2 % mod;
}
b.resize(lim);
if (mod - b[0] < b[0]) {
for (int i = 0; i < len(b); i ++) b[i] = (0 - b[i] + mod) % mod;
}
return b;
}
poly polyderiv(poly f) {
for (int i = 0; i < len(f) - 1; i ++) f[i] = f[i + 1] * (i + 1) % mod;
f.pop_back();
return f;
}
poly polyinteg(poly f) {
for (int i = len(f) - 1; i; i --) f[i] = f[i - 1] * qmi(i, mod - 2) % mod;
f[0] = 0;
return f;
}
poly polyln(poly f, int n) {
f = polyinteg(polyderiv(f) * polyinv(f, n));
f.resize(n);
return f;
}
poly polyexp(poly f, int n) {
poly b(1, 1), c;
for (int len = 2; (len >> 1) < n; len <<= 1) {
c = polyln(b, len);
for (int i = 0; i < len; i ++) c[i] = ((i < len(f) ? f[i] : 0) - c[i] + mod) % mod;
c = c + 1, b = b * c;
b.resize(len);
}
b.resize(n);
return b;
}
poly polyqmi(poly f, int n, int k) {
f = polyexp(polyln(f, n) * k, n);
return f;
}
signed main() {
cin >> n >> m >> k;
if (!m) {
cout << 1 << endl;
return 0;
} else if (!n) {
cout << 0 << endl;
return 0;
} else if (n < m) {
cout << 0 << endl;
return 0;
}
f.resize(k + 1);
for (int i = 0; i <= k; i ++) f[i] = 1;
poly res = polyqmi(f, m + 1, n - m + 1);
if (!k) {
cout << 0 << endl;
} else if (k == 1) {
cout << res[m] << endl;
} else {
f.resize(k);
for (int i = 0; i <= k - 1; i ++) f[i] = 1;
poly res2 = polyqmi(f, m + 1, n - m + 1);
cout << (res[m] - res2[m] + mod) % mod << endl;
}
}