2021ccpc威海 M 810975(容斥/生产函数/多项式快速幂)

目录

容斥

 生成函数/多项式快速幂


题意很简单

n场比赛,赢了m场,最大连赢k场有多少种排列

容斥

一般这种恰好k场的题目会考虑先算出大于等于k场的排列减去大于等于k+1场的排列

首先我们定义ans_{k}为 n场赢m场,最大连赢大于等于k的排列

答案为ans_{k}-ans_{k+1}

那现在考虑怎么计算ansk

首先我们可以枚举i为 1个k连胜,2个k连胜.......

然后进行容斥,奇加偶减

对于第ik连胜

我们首先要在n-m个输的场的n-m+1空隙中选i个来放我们i个k连胜

                                          \binom{n-m+1}{i}

然后接下来剩下的球任意放在n-m+1的空隙里,可以有空

剩下球的数目为n-(n-m)-i*k=m-i*k

等同于把n个球放进m个盒子里,可以有空盒子

不懂可以去看->(50条消息) [概率练习]n个小球放入m个盒子_Nightmare004的博客-CSDN博客_n个球放入m个盒子定理

然后把答案容斥一下可以得到

ans_{k}=\sum_{i=1}^{m-i*k\geq 0}(-1)^{i+1}\binom{n-m+1}{i}\binom{n-i*k}{n-m}

输出ans_{k}-ans_{k+1}即可

代码如下:

#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define fer(i,a,b) for(int i=a;i<=b;++i)
#define der(i,a,b) for(int i=a;i>=b;--i)
#define all(x) (x).begin(),(x).end()
#define pll pair<int,int>
#define et  cout<<'\n'
#define xx first
#define yy second
#define double long double
using namespace std;
const int mod=998244353;
const int N=1e6+10;
int qpow(int a,int b)
{
    int t=1;
    while(b)
    {
        if(b&1)t=t*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return t;
}
inline int mul(int x, int y) {
    return (x * y) % mod;
}
int fc[N],gc[N];
inline void init(){
    fc[0]=1;
    for(int i=1;i<=500001;i++) fc[i]=mul(fc[i-1],i);
    gc[500001]=qpow(fc[500001],mod-2);
    for(int i=500001;i>=1;i--) gc[i-1]=mul(gc[i],i);
}
inline int C(int i,int j){
	if(j>i)return 0;
	return mul(mul(fc[i],gc[j]),gc[i-j]);
}//大组合数
int n,m,k;
int cal(int k){
    int ans=0;
    for(int i=1;i*k<=m;i++)
    {
        if(i&1)ans=(ans+C(n-m+1,i)*C(n-i*k,n-m)%mod)%mod;
        else ans=(ans-C(n-m+1,i)*C(n-i*k,n-m)%mod)%mod;
    }
    return ans;
}
signed main()
{
    init();
    cin>>n>>m>>k;
    int ans=0;
    if(k==0)
    {
        cout<<(m==0)<<endl;
        return 0;
    }
    ans=((cal(k)-cal(k+1))%mod+mod)%mod;
    cout<<ans<<endl;
    return 0;
}

 生成函数/多项式快速幂

生成函数的解法更暴力好像一些

对于最多连胜小于等于k场

n-m+1个空位,每个空位放1胜,2胜,......

这样每个函数生成函数为

                          x^{1}+x^{2}+x^{3}......x^{k}

而一共有n-m+1个空位

所以对于f(k)答案为

                          (x^{1}+x^{2}+x^{3}......x^{k})^{n-m+1}

的 m次项的系数

最后输出f(k)-f(k+1)即可

代码如下

#include <bits/stdc++.h>
#define int long long
#define x first
#define y second
#define poly vector<int>
#define len(x) ((int)x.size())
using namespace std;
const int N = 3e5 + 5, G = 3, Ginv = 332748118, mod = 998244353;
int n, m, rev[N], lim, w, k;
poly f, g;
struct nd {
    int x, y;
    nd (int a = 0, int b = 0) {
        x = a;
        y = b;
    }
    nd operator * (const nd &a) const {
        return nd((x * a.x % mod + y * a.y % mod * w % mod) % mod, (x * a.y % mod + y * a.x % mod) % mod);
    }
};
int readmod() {
    int x = 0;
    char s = getchar();
    while (s < '0' || s > '9') s = getchar();
    while (s >= '0' && s <= '9') {
        x = ((x << 1) + (x << 3) + (s - '0')) % mod;
        s = getchar();
    }
    return x;
}
int qmi(int a, int b) {
    int res = 1;
    while (b) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
nd qmi(nd a, int b, nd res = nd(1, 0)) {
    while (b) {
        if (b & 1) res = res * a;
        a = a * a;
        b >>= 1;
    }
    return res;
}
bool check_sqrt(int a) {
    return qmi(a, (mod - 1) / 2) == 1;
}
int get_w(int n) {
    for (int a = rand(); ; a = rand()) {
        if (!check_sqrt((a * a % mod - n + mod) % mod)) {
            return w = (a * a % mod - n + mod) % mod, a;
        }
    }
}
int get_sqrt(int n) {
    int a = get_w(n), x_1 = (qmi(nd(a, 1), (mod + 1) / 2)).x, x_2 = mod - x_1;
    return x_1 < x_2 ? x_1 : x_2;
}
void polyinit(int n) {
    for (lim = 1; lim < n; lim <<= 1);
    for (int i = 0; i < lim; i ++) rev[i] = (rev[i >> 1] >> 1) | (i & 1 ? lim >> 1 : 0);
}
void NTT(poly &f, int op) {
    for (int i = 0; i < lim; i ++) {
        if (i < rev[i]) swap(f[i], f[rev[i]]);
    }
    for (int mid = 1; mid < lim; mid <<= 1) {
        int Gn = qmi(op == 1 ? G : Ginv, (mod - 1) / (mid << 1));
        for (int i = 0; i < lim; i += mid * 2) {
            for (int j = 0, G0 = 1; j < mid; j ++, G0 = G0 * Gn % mod) {
                int x = f[i + j], y = G0 * f[i + j + mid] % mod;
                f[i + j] = (x + y) % mod, f[i + j + mid] = (x - y + mod) % mod;
            }
        }
    }
    if (op == -1) {
        int inv = qmi(lim, mod - 2);
        for (int i = 0; i < lim; i ++) f[i] = f[i] * inv % mod;
    }
}
poly operator + (poly f, int x) {
    f[0] = (f[0] + x) % mod;
    return f;
}
poly operator - (poly f, int x) {
    f[0] = (f[0] - x + mod) % mod;
    return f;
}
poly operator * (poly f, int x) {
    for (int i = 0; i < len(f); i ++) f[i] = f[i] * x % mod;
    return f;
}
poly operator + (poly f, poly g) {
    poly res = f;
    res.resize(max(len(f), len(g)));
    for (int i = 0; i < len(g); i ++) res[i] = (res[i] + g[i]) % mod;
    return res;
}
poly operator - (poly f, poly g) {
    poly res = f;
    res.resize(max(len(f), len(g)));
    for (int i = 0; i < len(g); i ++) res[i] = (res[i] - g[i] + mod) % mod;
    return res;
}
poly operator * (poly f, poly g) {
    int n = len(f) + len(g) - 1;
    polyinit(n), f.resize(lim), g.resize(lim);
    NTT(f, 1), NTT(g, 1);
    for (int i = 0; i < lim; i ++) f[i] = f[i] * g[i] % mod;
    NTT(f, -1), f.resize(n);
    return f;
}
poly polyinv(poly f, int n) {
    poly res(1, qmi(f[0], mod - 2)), t;
    for (int len = 2; len < n << 1; len <<= 1) {
        t.resize(len), polyinit(len << 1);
        for (int i = 0; i < len; i ++) t[i] = i < len(f) ? f[i] : 0;
        t.resize(lim), res.resize(lim);
        NTT(t, 1), NTT(res, 1);
        for (int i = 0; i < lim; i ++) res[i] = res[i] * (2 - res[i] * t[i] % mod + mod) % mod;
		NTT(res, -1), res.resize(len);
    }
    res.resize(n);
    return res;
}
pair<poly, poly> polydiv(poly f, poly g) {
    int n = len(f) - len(g) + 1;
    poly q = f;
    reverse(q.begin(), q.end()), q.resize(n);
    reverse(g.begin(), g.end());
    q = q * polyinv(g, n), q.resize(n);
    reverse(q.begin(), q.end()), reverse(g.begin(), g.end());
    poly r = f - g * q;
    return {q, r};
}
poly polysqrt(poly f, int n) {
    int inv2 = qmi(2, mod - 2);
    poly b(1, get_sqrt(f[0])), c, d;
    for (int len = 4; (len >> 2) < n; len <<= 1) {
        c = f, c.resize(len >> 1), polyinit(len);
        d = polyinv(b, len >> 1);
        c.resize(lim), d.resize(lim);
        NTT(c, 1), NTT(d, 1);
        for (int i = 0; i < lim; i ++) c[i] = c[i] * d[i] % mod;
        NTT(c, -1);
        b.resize(len >> 1);
        for (int i = 0; i < (len >> 1); i ++) b[i] = (c[i] + b[i]) % mod * inv2 % mod;
    }
    b.resize(lim);
    if (mod - b[0] < b[0]) {
        for (int i = 0; i < len(b); i ++) b[i] = (0 - b[i] + mod) % mod;
    }
    return b;
}
poly polyderiv(poly f) {
    for (int i = 0; i < len(f) - 1; i ++) f[i] = f[i + 1] * (i + 1) % mod;
    f.pop_back();
    return f;
}
poly polyinteg(poly f) {
    for (int i = len(f) - 1; i; i --) f[i] = f[i - 1] * qmi(i, mod - 2) % mod;
    f[0] = 0;
    return f;
}
poly polyln(poly f, int n) {
    f = polyinteg(polyderiv(f) * polyinv(f, n));
    f.resize(n);
    return f;
}
poly polyexp(poly f, int n) {
    poly b(1, 1), c;
    for (int len = 2; (len >> 1) < n; len <<= 1) {
        c = polyln(b, len);
        for (int i = 0; i < len; i ++) c[i] = ((i < len(f) ? f[i] : 0) - c[i] + mod) % mod;
        c = c + 1, b = b * c;
        b.resize(len);
    }
    b.resize(n);
    return b;
}
poly polyqmi(poly f, int n, int k) {
    f = polyexp(polyln(f, n) * k, n);
    return f;
}
signed main() {
    cin >> n >> m >> k;
    if (!m) {
        cout << 1 << endl;
        return 0;
    } else if (!n) {
        cout << 0 << endl;
        return 0;
    } else if (n < m) {
        cout << 0 << endl;
        return 0;
    }
    f.resize(k + 1);
    for (int i = 0; i <= k; i ++) f[i] = 1;
    poly res = polyqmi(f, m + 1, n - m + 1);
    if (!k) {
        cout << 0 << endl;
    } else if (k == 1) {
        cout << res[m] << endl;
    } else {
        f.resize(k);
        for (int i = 0; i <= k - 1; i ++) f[i] = 1;
        poly res2 = polyqmi(f, m + 1, n - m + 1);
        cout << (res[m] - res2[m] + mod) % mod << endl;
    }
}

  • 6
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值