【2021牛客多校8H】Scholomance Academy (线性递推)

题目链接

题目大意

$$G(N)=\sum_{k_1+k_2+\cdots+k_t=N}F\left(p_1^{k_1}p_2^{k_2}\cdots p_t^{k_t}\right)$$
$$F(n)=\sum_{a_1a_2\cdots a_m=n}\varphi(a_1)\varphi(a_2)\cdots\varphi(a_m)$$
给定 $N,t,m$ 以及素数 $p_1,p_2,\dots,p_t$,求 $G(N)\bmod 998244353$。

思路

$F$ 是 $m$ 个 $\varphi$ 的狄利克雷卷积,而积性函数的狄利克雷卷积仍是积性函数,故 $F(n)$ 为积性函数,只需考虑 $F(p^k)$。
记(利用 $\varphi(1)=1$,$\varphi(p^k)=p^k-p^{k-1}\ (k\ge 1)$)
$$h_p(x)=\sum_{k=0}^{\infty}\varphi(p^k)x^k=1+\sum_{k\ge 1}\left(p^k-p^{k-1}\right)x^k=\frac{1-x}{1-px}$$

由于 $F$ 是 $m$ 个 $\varphi$ 的卷积,在素数幂处它对应幂级数的 $m$ 次方:
$$F(p^k)=[x^k]\big(h_p(x)\big)^m$$
于是 $F(p^k)$ 的生成函数为
$$f_p(x)=\sum_{k=0}^{\infty}F(p^k)x^k=\big(h_p(x)\big)^m=\left(\frac{1-x}{1-px}\right)^m$$

由积性,$F(p_1^{k_1}\cdots p_t^{k_t})=\prod_{i=1}^{t}F(p_i^{k_i})$,故
$$G(N)=[x^N]\prod_{i=1}^{t}f_{p_i}(x)=[x^N]\prod_{i=1}^{t}\left(\frac{1-x}{1-p_ix}\right)^m$$
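记 $P(x)=(1-x)^{mt}$,$Q(x)=\prod_{i=1}^{t}(1-p_ix)^m$。二者都是 $mt$ 次多项式,且 $Q(0)=1$,于是 $G$ 的生成函数为 $P(x)/Q(x)$。

先用多项式求逆求出前 $mt+1$ 项 $G(0),\dots,G(mt)$;若 $N\le mt$,直接输出即可。

否则比较 $Q(x)\sum_{n\ge 0}G(n)x^n=P(x)$ 两边 $x^n$($n>mt$)的系数,得到 $mt$ 阶常系数线性递推
$$G(n)=-\sum_{j=1}^{mt}q_j\,G(n-j),\qquad q_j=[x^j]Q(x).$$
以 $G(1),\dots,G(mt)$ 为初值,即可求出第 $N$ 项。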
常系数齐次线性递推的具体做法见之前写的一篇文章:用多项式取模配合快速幂计算 $x^N$ 模特征多项式,复杂度 $O(k\log k\log N)$,其中 $k=mt$。这里直接套用该做法即可。

代码

#include <bits/stdc++.h>
// Inclusive ascending / descending loop helpers. Every argument is
// parenthesized so a bound containing a low-precedence operator
// (e.g. `x & 1`, `c ? a : b`) is compared as a whole rather than being
// split by `<=` / `>=`.
#define rep(i, l, r) for (int i = (l); i <= (r); ++i)
#define per(i, r, l) for (int i = (r); i >= (l); --i)

using namespace std;
// Capacity (in ints) of each static NTT work buffer used below.
constexpr int N = 1048576;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; DFT uses 3 as the root.
constexpr int mod = 998244353;
// Polynomial / coefficient vector type.
using vi = vector<int>;
// Modular exponentiation: returns x^y mod `mod` by square-and-multiply,
// scanning the exponent's bits from least significant upward.
int pw(int x, int y) {
    long long base = x;
    long long acc = 1;
    for (; y; y >>= 1) {
        if (y & 1) acc = acc * base % mod;
        base = base * base % mod;
    }
    return (int)acc;
}
// Modular inverse of x modulo the prime `mod`, via Fermat: x^(mod - 2).
int inv(int x) {
    const int fermat_exponent = mod - 2;
    return pw(x, fermat_exponent);
}

// Bit-reversal permutation table shared with DFT().
int rev[N];
// Fills rev[0..n) with the bit-reversal permutation for transform size n
// (expected to be a power of two). The table is rebuilt only when n differs
// from the size used on the previous call.
void get_rev(int n) {
    static int cached = 1;
    if (cached == n) return;
    cached = n;
    int logn = 0;
    while ((1 << logn) < n) ++logn;
    rev[0] = 0;
    for (int i = 1; i < n; ++i) {
        const int shifted = rev[i >> 1] >> 1;
        const int top_bit = (i & 1) << (logn - 1);
        rev[i] = shifted | top_bit;
    }
}
void DFT(int *a, int n, int dir) {
    get_rev(n);
    rep(i, 0, n - 1) {
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    }
    for (int len = 1; (len << 1) <= n; len <<= 1) {
        int wn = pw(3, (mod - 1) / (len << 1));
        if (dir == -1) wn = inv(wn);
        for (int i = 0; i < n; i += len << 1) {
            int w = 1;
            rep(j, i, i + len - 1) {
                int tmp = 1ll * a[j + len] * w % mod;
                a[j + len] = (a[j] - tmp + mod) % mod;
                a[j] = (a[j] + tmp) % mod;
                w = 1ll * w * wn % mod;
            }
        }
    }
    if (dir == -1) {
        int invn = inv(n);
        rep(i, 0, n - 1) { a[i] = 1ll * a[i] * invn % mod; }
    }
}

// Coefficient-wise sum of two polynomials modulo `mod`. The result has
// max(a.size(), b.size()) coefficients; either operand may be empty.
vi operator+(const vi &a, const vi &b) {
    vi ret(max(a.size(), b.size()), 0);
    // size_t bounds: the previous `i <= a.size() - 1` underflowed to SIZE_MAX
    // for an empty operand and indexed past the end of the vector.
    for (size_t i = 0; i < a.size(); ++i) ret[i] = (ret[i] + a[i]) % mod;
    for (size_t i = 0; i < b.size(); ++i) ret[i] = (ret[i] + b[i]) % mod;
    return ret;
}
// Coefficient-wise difference a - b of two polynomials modulo `mod`. The
// result has max(a.size(), b.size()) coefficients; either operand may be empty.
vi operator-(const vi &a, const vi &b) {
    vi ret(max(a.size(), b.size()), 0);
    // size_t bounds: the previous `i <= a.size() - 1` underflowed to SIZE_MAX
    // for an empty operand and indexed past the end of the vector.
    for (size_t i = 0; i < a.size(); ++i) ret[i] = (ret[i] + a[i]) % mod;
    for (size_t i = 0; i < b.size(); ++i) ret[i] = (ret[i] - b[i] + mod) % mod;
    return ret;
}
// Polynomial product modulo `mod` via NTT. The result has
// a.size() + b.size() - 1 coefficients, or none if both operands are empty.
vi operator*(const vi &a, const vi &b) {
    static int ta[N], tb[N];
    const int na = a.size(), nb = b.size();
    const int n = na + nb - 1;
    int lim = 1;
    while (lim < n) lim <<= 1;
    // The transform writes lim entries of the static buffers; fail loudly
    // instead of silently overrunning them.
    assert(lim <= N && "polynomial product exceeds NTT buffer size N");
    rep(i, 0, lim - 1) {
        ta[i] = (i < na) ? a[i] : 0;
        tb[i] = (i < nb) ? b[i] : 0;
    }
    DFT(ta, lim, 1);
    DFT(tb, lim, 1);
    rep(i, 0, lim - 1) { ta[i] = 1ll * ta[i] * tb[i] % mod; }
    DFT(ta, lim, -1);
    return vi(ta, ta + max(n, 0));
}

// Power-series inverse modulo `mod`: returns b with a(x) * b(x) == 1
// (mod x^n), where n = a.size().
// Newton iteration: starting from b = 1 / a[0], each round doubles the
// precision len and applies b <- b * (2 - a * b) (mod x^len). The update is
// evaluated with NTTs of size 2 * len so the product does not wrap around.
// Precondition: a[0] != 0 (mod `mod`). An empty input yields an empty result.
vi inv(vi a) {
    static int ta[N], tb[N];
    int n = a.size();
    if (n == 0) return vi();  // nothing to invert; a[0] below would be UB
    int lim = 1;
    while (lim < n) lim <<= 1;
    // The final round transforms 2 * lim entries of the static buffers.
    assert((lim << 1) <= N && "series inverse exceeds NTT buffer size N");
    tb[0] = inv(a[0]);
    rep(i, 1, (lim << 1) - 1) tb[i] = 0;
    for (int len = 2; len <= lim; len <<= 1) {
        // ta = a mod x^len, zero-padded to 2 * len
        rep(i, 0, (len << 1) - 1) { ta[i] = (i < len && i < n) ? a[i] : 0; }
        DFT(ta, len << 1, 1);
        DFT(tb, len << 1, 1);
        rep(i, 0, (len << 1) - 1) {
            tb[i] = (2ll - 1ll * ta[i] * tb[i] % mod + mod) * tb[i] % mod;
        }
        DFT(tb, len << 1, -1);
        // keep only the new precision: b mod x^len
        rep(i, len, (len << 1) - 1) tb[i] = 0;
    }
    vi ret;
    rep(i, 0, n - 1) { ret.push_back(tb[i]); }
    return ret;
}

// Polynomial quotient floor(a / b) modulo `mod`, computed by reversal:
// rev(q) = rev(a) * rev(b)^{-1} (mod x^{n-m+1}), with n = a.size() and
// m = b.size(). Returns {0} when a has fewer coefficients than b.
// The leading coefficient b.back() must be invertible.
vi operator/(vi a, vi b) {
    const int la = a.size(), lb = b.size();
    if (la < lb) return vi(1, 0);
    const int qlen = la - lb + 1;
    vi ra(a.rbegin(), a.rend());
    vi rb(b.rbegin(), b.rend());
    ra.resize(qlen);
    rb.resize(qlen);
    vi quot = ra * inv(rb);
    quot.resize(qlen);
    return vi(quot.rbegin(), quot.rend());
}

// Polynomial remainder a mod b modulo `mod`: a - (a / b) * b, resized to
// exactly b.size() - 1 coefficients.
vi operator%(vi a, vi b) {
    const vi quotient = a / b;
    vi rem = a - quotient * b;
    rem.resize(b.size() - 1);
    return rem;
}

// Computes a(x)^k mod p(x), with coefficients modulo `mod`, and returns the
// low p.size() - 1 coefficients of the remainder. linear() calls it with
// a(x) = x and p(x) = the characteristic polynomial of the recurrence.
//
// Method: binary exponentiation whose operands live in static NTT buffers of
// length 2 * lim (lim >= m = p.size()). Each product of two remainders has
// at most n = 2m - 3 coefficients. It is reduced mod p with the reversed-
// polynomial trick:
//   quotient = reverse( rev(prod)[0, n-m+1) * inv(rev(p)) mod x^{n-m+1} )
//   remainder = prod - quotient * p
// The transforms of p (tp) and of inv(rev(p)) (tpr) are computed once up front.
// NOTE(review): assumes p.back() != 0 (needed by inv(pr)) and that
// 2 * lim <= N for the static buffers; neither is checked here.
vi pw(vi a, int k, vi p) {
    static int ret[N], ta[N], tp[N], tpr[N], q[N];
    int m = p.size();
    int n = m * 2 - 3;  // coefficient count of a product of two remainders
    int lim = 1;
    while (lim < m) lim <<= 1;
    // rep(i, 0, m - 1) printf("%d ", p[i]);
    // printf(" (p)\n");
    vi aa = a % p;
    // p linear: the remainder is the constant aa[0], so the answer is aa[0]^k.
    if (m == 2) return vi(1, pw(aa[0], k));
    rep(i, 0, (lim << 1) - 1) { ta[i] = (i < m - 1) ? aa[i] : 0; }
    // rep(i, 0, m - 2) printf("%d ", ta[i]);
    // printf(" (a)\n");

    // pr = inverse of reversed p, to n - m + 1 (= m - 2) terms.
    vi pr = p;
    reverse(pr.begin(), pr.end());
    pr.resize(n - m + 1);
    pr = inv(pr);
    rep(i, 0, (lim << 1) - 1) {
        tp[i] = (i < m) ? p[i] : 0;
        tpr[i] = (i < n - m + 1) ? pr[i] : 0;
    }
    DFT(tp, lim << 1, 1);
    DFT(tpr, lim << 1, 1);

    // ret = 1 (the running result, kept in coefficient form between steps)
    ret[0] = 1;
    rep(i, 1, (lim << 1) - 1) ret[i] = 0;

    while (k) {
        // ta holds the current power a^(2^j) mod p in coefficient form.
        DFT(ta, lim << 1, 1);
        if (k & 1) {
            DFT(ret, lim << 1, 1);

            rep(i, 0, (lim << 1) - 1) ret[i] = 1ll * ret[i] * ta[i] % mod;

            DFT(ret, lim << 1, -1);

            // rep(i, 0, (lim << 1) - 1) printf("%d ", ret[i]);
            // printf(" (ret)\n");

            // reduce ret (n coefficients) mod p:
            // first the quotient from the reversed top n - m + 1 coefficients
            rep(i, 0, (lim << 1) - 1) { q[i] = (i < n) ? ret[i] : 0; }
            reverse(q, q + n);
            rep(i, n - m + 1, n - 1) q[i] = 0;
            DFT(q, lim << 1, 1);
            rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tpr[i] % mod;
            DFT(q, lim << 1, -1);
            reverse(q, q + n - m + 1);
            rep(i, n - m + 1, (lim << 1) - 1) q[i] = 0;

            // then ret -= quotient * p
            DFT(q, lim << 1, 1);
            rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tp[i] % mod;
            DFT(q, lim << 1, -1);
            rep(i, 0, (lim << 1) - 1) { ret[i] = (ret[i] - q[i] + mod) % mod; }
            /*
            rep(i, 0, (lim << 1) - 1) printf("%d ", ret[i]);
            printf(" (ret mod p)\n");
            */
        }
        // square the base in the transform domain, back to coefficients
        rep(i, 0, (lim << 1) - 1) ta[i] = 1ll * ta[i] * ta[i] % mod;
        DFT(ta, lim << 1, -1);

        // rep(i, 0, (lim << 1) - 1) printf("%d ", ta[i]);
        // printf(" (new a)\n");

        // reduce ta mod p, same two steps as for ret above: quotient first
        rep(i, 0, (lim << 1) - 1) { q[i] = (i < n) ? ta[i] : 0; }
        reverse(q, q + n);
        rep(i, n - m + 1, n - 1) q[i] = 0;
        // rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
        // printf(" (qr)\n");
        DFT(q, lim << 1, 1);
        rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tpr[i] % mod;
        DFT(q, lim << 1, -1);
        // rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
        // printf(" (qr * inv(pr))\n");
        reverse(q, q + n - m + 1);
        rep(i, n - m + 1, (lim << 1) - 1) q[i] = 0;
        // rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
        // printf(" (new a / p)\n");

        // then ta -= quotient * p
        DFT(q, lim << 1, 1);
        rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tp[i] % mod;
        DFT(q, lim << 1, -1);
        rep(i, 0, (lim << 1) - 1) { ta[i] = (ta[i] - q[i] + mod) % mod; }

        // rep(i, 0, (lim << 1) - 1) printf("%d ", ta[i]);
        // printf(" (new a mod p)\n");
        k >>= 1;
    }

    // the remainder has degree < m - 1
    vi c;
    rep(i, 0, m - 2) c.push_back(ret[i]);
    return c;
}

// n-th term of a linear recurrence. g is its characteristic polynomial,
// lowest degree first, of degree k = g.size() - 1. a holds the first k terms.
// Computes r(x) = x^n mod g(x) and returns sum_i r_i * a_i modulo `mod`.
int linear(vi g, vi a, int n) {
    const int order = g.size() - 1;
    const vi xpow = pw(vi{0, 1}, n, g);
    long long acc = 0;
    for (int i = 0; i < order; ++i) {
        acc = (acc + 1ll * xpow[i] * a[i]) % mod;
    }
    return (int)acc;
}

void test_linear() {
    static int f[N];
    int n, k;
    scanf("%d%d", &n, &k);
    rep(i, 1, k) {
        scanf("%d", &f[i]);
        f[i] = (f[i] + mod) % mod;
    }
    vi a;
    int x;
    rep(i, 0, k - 1) {
        scanf("%d", &x);
        a.push_back((x + mod) % mod);
    }
    vi g;
    per(i, k, 1) { g.push_back((mod - f[i]) % mod); }
    g.push_back(1);

    int ret = linear(g, a, n);
    printf("%d", ret);
}

int n, t, m;
int p[N];
// Returns the coefficients of prod_{i=l}^{r} (1 - p[i] x)^m modulo `mod`,
// lowest degree first. The product is built by divide and conquer with NTT
// multiplication. An empty range (l > r) is the empty product, i.e. 1.
vi solve(int l, int r) {
    if (l > r) return vi(1, 1);  // previously recursed forever (e.g. t == 0)
    if (l == r) {
        // (1 - p x)^m = sum_j C(m, j) * (-p)^j * x^j
        // Reduce p first: `mod - p[l]` alone is negative when p[l] > mod.
        const int neg_p = (mod - p[l] % mod) % mod;
        vi g(1, 1);
        int c = 1, pp = 1;  // c = C(m, j), pp = (-p)^j
        rep(j, 1, m) {
            pp = 1ll * pp * neg_p % mod;
            c = 1ll * c * (m - j + 1) % mod * inv(j) % mod;
            g.push_back(1ll * c * pp % mod);
        }
        return g;
    }
    int mid = (l + r) / 2;
    return solve(l, mid) * solve(mid + 1, r);
}
void work() {
    scanf("%d%d%d", &n, &t, &m);
    rep(i, 1, t) { scanf("%d", &p[i]); }
    vi f = solve(1, t);
    // rep(i, 0, m * t) { printf("%d ", f[i]); }
    vi h(1, 1);
    int c = 1;
    rep(j, 1, m * t) {
        c = 1ll * c * (m * t - j + 1) % mod * inv(mod - j) % mod;
        h.push_back(c);
    }
    h = h * inv(f);
    // rep(i, 0, m * t) { printf("%d ", h[i]); }
    int ret;
    if (n <= m * t) {
        ret = h[n];
    } else {
        reverse(f.begin(), f.end());
        vi a(h.begin() + 1, h.end());
        --n;
        ret = linear(f, a, n);
    }
    printf("%d", ret);
}

// Program entry: solve a single test case. Falling off the end of main
// returns 0.
int main() {
    work();
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值