常系数齐次线性递推

题目链接

题目描述

数列 $\{a_n\}$ 满足 $k$ 阶线性递推关系:$a_n=\sum_{i=1}^k f_i a_{n-i}\ (n\ge k)$。

现给定 $n,k$ 以及 $f_1,f_2,\dots,f_k$ 和 $a_0,a_1,\dots,a_{k-1}$,求 $a_n \bmod 998244353$ 的值。

数据范围

$n\le 10^9,\ k\le 32000$

思路

令
$$F=\begin{bmatrix}0&1&0&\cdots&0\\0&0&1&\cdots&0\\\vdots&\vdots&\vdots&\ddots&\vdots\\0&0&0&\cdots&1\\f_k&f_{k-1}&f_{k-2}&\cdots&f_1\end{bmatrix}$$
则
$$\begin{bmatrix}a_n\\a_{n+1}\\\vdots\\a_{n+k-1}\end{bmatrix}=F^n\begin{bmatrix}a_0\\a_1\\\vdots\\a_{k-1}\end{bmatrix}$$
$F$ 的特征多项式为 $g(\lambda)=|\lambda I-F|=\lambda^k-\sum_{i=1}^k f_i\lambda^{k-i}$。
根据代数基本定理,可以把 $g(\lambda)$ 表示为 $g(\lambda)=\prod_{i=1}^t(\lambda-\lambda_i)^{m_i}$,其中 $\lambda_i$ 各不相同,且 $\sum_{i=1}^t m_i=k$。
因此 $a_n=\sum_{i=1}^t\lambda_i^n\left(\sum_{j=1}^{m_i}c_{ij}n^{j-1}\right)$,其中 $c_{ij}$ 可以通过待定系数法确定。
但是 $\lambda_i$ 怎么求?$c_{ij}$ 又怎么求?直接这样做比较麻烦。

这时可以用到 Hamilton–Cayley 定理:$g(F)=F^k-\sum_{i=1}^k f_iF^{k-i}=O$。

令 $r(\lambda)=\lambda^n \bmod g(\lambda)=\sum_{i=0}^{k-1}r_i\lambda^i$。由于 $\lambda^n=q(\lambda)g(\lambda)+r(\lambda)$,代入 $F$ 并利用 $g(F)=O$,得 $F^n=r(F)$,于是
$$\begin{bmatrix}a_n\\a_{n+1}\\\vdots\\a_{n+k-1}\end{bmatrix}=r(F)\begin{bmatrix}a_0\\a_1\\\vdots\\a_{k-1}\end{bmatrix}=\sum_{i=0}^{k-1}r_iF^i\begin{bmatrix}a_0\\a_1\\\vdots\\a_{k-1}\end{bmatrix}$$
而当 $0\le i<k$ 时,$F^i\,[a_0,a_1,\dots,a_{k-1}]^T$ 的第一个分量恰为 $a_i$,比较两边第一个分量即得 $a_n=\sum_{i=0}^{k-1}r_ia_i$。

总结一下:

  1. 求出递推关系的特征多项式 $g(\lambda)=\lambda^k-\sum_{i=1}^k f_i\lambda^{k-i}$。
  2. 求出 $r(\lambda)=\lambda^n \bmod g(\lambda)=\sum_{i=0}^{k-1}r_i\lambda^i$。这一步需要多项式快速幂与多项式取模,总复杂度 $O(k\log k\log n)$。
  3. 求出 $a_n=\sum_{i=0}^{k-1}r_ia_i$。

代码

#include <bits/stdc++.h>
// Loop shorthands used throughout the file:
//   rep(i, l, r) walks i = l, l+1, ..., r  (both bounds inclusive)
//   per(i, r, l) walks i = r, r-1, ..., l  (both bounds inclusive)
#define rep(i, l, r) for (int i = l; i <= r; ++i)
#define per(i, r, l) for (int i = r; i >= l; --i)

using namespace std;
// Capacity (2^20 entries) of every static NTT work buffer in this file.
constexpr int N = 1048576;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1; DFT() uses 3 as its root.
constexpr int mod = 998244353;
// A polynomial stored low degree first: element i is the coefficient of x^i.
using vi = vector<int>;
// Modular exponentiation by square-and-multiply: returns x^y modulo `mod`.
// Expects y >= 0 (a negative y never reaches zero under >>= and loops forever).
int pw(int x, int y) {
    long long result = 1;
    long long base = x;
    for (; y != 0; y >>= 1) {
        if (y & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(result);
}
// Modular inverse via Fermat's little theorem (mod is prime): x^(mod - 2).
// Note inv(0) evaluates to 0, which is not a real inverse.
int inv(int x) {
    const int exponent = mod - 2;
    return pw(x, exponent);
}

// Bit-reversal permutation used by DFT(): after get_rev(n), rev[i] is i with
// its low log2(n) bits reversed (n is expected to be a power of two).
int rev[N];
void get_rev(int n) {
    static int cached = 1;  // size the table was last built for
    if (cached == n) return;
    cached = n;
    int bits = 0;
    while ((1 << bits) < n) {
        ++bits;
    }
    rev[0] = 0;
    for (int i = 1; i < n; ++i) {
        // Reverse i by reusing the reversal of i / 2 and moving i's lowest bit
        // to the top position.
        const int low_bit = i & 1;
        rev[i] = (rev[i >> 1] >> 1) | (low_bit << (bits - 1));
    }
}
void DFT(int *a, int n, int dir) {
    get_rev(n);
    rep(i, 0, n - 1) {
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    }
    for (int len = 1; (len << 1) <= n; len <<= 1) {
        int wn = pw(3, (mod - 1) / (len << 1));
        if (dir == -1) wn = inv(wn);
        for (int i = 0; i < n; i += len << 1) {
            int w = 1;
            rep(j, i, i + len - 1) {
                int tmp = 1ll * a[j + len] * w % mod;
                a[j + len] = (a[j] - tmp + mod) % mod;
                a[j] = (a[j] + tmp) % mod;
                w = 1ll * w * wn % mod;
            }
        }
    }
    if (dir == -1) {
        int invn = inv(n);
        rep(i, 0, n - 1) { a[i] = 1ll * a[i] * invn % mod; }
    }
}

// Coefficient-wise sum a + b of two polynomials, reduced modulo `mod`.
// The result has max(a.size(), b.size()) coefficients; either operand may be
// empty. Coefficients are expected to be already reduced into [0, mod).
vi operator+(const vi &a, const vi &b) {
    vi ret(max(a.size(), b.size()), 0);
    // size_t loops: the former `a.size() - 1` bound wrapped around for empty input.
    for (size_t i = 0; i < a.size(); ++i) ret[i] = (ret[i] + a[i]) % mod;
    for (size_t i = 0; i < b.size(); ++i) ret[i] = (ret[i] + b[i]) % mod;
    return ret;
}
// Coefficient-wise difference a - b of two polynomials, reduced into [0, mod).
// The result has max(a.size(), b.size()) coefficients; either operand may be
// empty. Coefficients are expected to be already reduced into [0, mod).
vi operator-(const vi &a, const vi &b) {
    vi ret(max(a.size(), b.size()), 0);
    // size_t loops: the former `a.size() - 1` bound wrapped around for empty input.
    for (size_t i = 0; i < a.size(); ++i) ret[i] = (ret[i] + a[i]) % mod;
    for (size_t i = 0; i < b.size(); ++i) ret[i] = (ret[i] - b[i] + mod) % mod;
    return ret;
}
// Polynomial product a * b modulo `mod` via NTT. Returns
// a.size() + b.size() - 1 coefficients (none when both inputs are empty).
// Uses function-local static buffers, so it is not reentrant.
vi operator*(vi a, vi b) {
    static int fa[N], fb[N];
    const int need = static_cast<int>(a.size() + b.size()) - 1;
    int size = 1;
    while (size < need) size *= 2;  // transform length: power of two >= need
    const int la = static_cast<int>(a.size());
    const int lb = static_cast<int>(b.size());
    for (int i = 0; i < size; ++i) {
        fa[i] = i < la ? a[i] : 0;
        fb[i] = i < lb ? b[i] : 0;
    }
    DFT(fa, size, 1);
    DFT(fb, size, 1);
    for (int i = 0; i < size; ++i) fa[i] = 1ll * fa[i] * fb[i] % mod;
    DFT(fa, size, -1);
    vi product;
    for (int i = 0; i < need; ++i) product.push_back(fa[i]);
    return product;
}

// Power-series inverse: returns the first a.size() coefficients of 1 / a(x),
// i.e. b with a * b == 1 (mod x^a.size()), by Newton iteration
// b <- b * (2 - a * b), doubling the number of correct terms each round.
// Preconditions (not checked): a is non-empty, a[0] != 0 (inv(0) yields 0),
// and 2 * lim (lim = next power of two >= a.size()) fits in the N-sized
// static buffers. Not reentrant: the work buffers are function-local statics.
vi inv(vi a) {
    static int ta[N], tb[N];
    int n = a.size();
    int lim = 1;
    while (lim < n) lim <<= 1;
    // Seed: the inverse modulo x^1 is 1 / a[0]. Clear the rest of tb so the
    // higher coefficients seen by the first transforms are zero.
    tb[0] = inv(a[0]);
    rep(i, 1, (lim << 1) - 1) tb[i] = 0;
    // Each round lifts tb from an inverse mod x^(len/2) to one mod x^len.
    for (int len = 2; len <= lim; len <<= 1) {
        // a truncated to len terms, zero-padded to the transform size 2*len.
        rep(i, 0, (len << 1) - 1) { ta[i] = (i < len && i < n) ? a[i] : 0; }
        DFT(ta, len << 1, 1);
        DFT(tb, len << 1, 1);
        // Pointwise Newton step tb * (2 - ta * tb). The full product has
        // degree below 2*len, so the size-2*len cyclic transform does not wrap.
        rep(i, 0, (len << 1) - 1) {
            tb[i] = (2ll - 1ll * ta[i] * tb[i] % mod + mod) * tb[i] % mod;
        }
        DFT(tb, len << 1, -1);
        // Keep only the len coefficients that are now correct.
        rep(i, len, (len << 1) - 1) tb[i] = 0;
    }
    vi ret;
    rep(i, 0, n - 1) { ret.push_back(tb[i]); }
    return ret;
}

// Polynomial quotient of a by b (both low degree first), modulo `mod`.
// Uses the reversal trick: reverse(q) = reverse(a) * inv(reverse(b))
// mod x^(deg a - deg b + 1). Returns {0} when a has fewer coefficients than b.
vi operator/(vi a, vi b) {
    const int na = a.size();
    const int nb = b.size();
    if (na < nb) return vi(1, 0);
    const int qlen = na - nb + 1;
    vi ra(a.rbegin(), a.rend());
    vi rb(b.rbegin(), b.rend());
    ra.resize(qlen);
    rb.resize(qlen);
    vi q = ra * inv(rb);
    q.resize(qlen);
    return vi(q.rbegin(), q.rend());
}

// Polynomial remainder a mod b, computed as a - (a / b) * b and resized to
// exactly b.size() - 1 coefficients (low degree first).
vi operator%(vi a, vi b) {
    const int rem_len = static_cast<int>(b.size()) - 1;
    vi rem = a - (a / b) * b;
    rem.resize(rem_len);
    return rem;
}

vi pw(vi a, int k, vi p) {
    static int ret[N], ta[N], tp[N], tpr[N], q[N];
    int m = p.size();
    int n = m * 2 - 3;
    int lim = 1;
    while (lim < m) lim <<= 1;
    // rep(i, 0, m - 1) printf("%d ", p[i]);
    // printf(" (p)\n");
    vi aa = a % p;
    if (m == 2) return vi(1, pw(aa[0], k));
    rep(i, 0, (lim << 1) - 1) { ta[i] = (i < m - 1) ? aa[i] : 0; }
    // rep(i, 0, m - 2) printf("%d ", ta[i]);
    // printf(" (a)\n");

    vi pr = p;
    reverse(pr.begin(), pr.end());
    pr.resize(n - m + 1);
    pr = inv(pr);
    rep(i, 0, (lim << 1) - 1) {
        tp[i] = (i < m) ? p[i] : 0;
        tpr[i] = (i < n - m + 1) ? pr[i] : 0;
    }
    DFT(tp, lim << 1, 1);
    DFT(tpr, lim << 1, 1);

    ret[0] = 1;
    rep(i, 1, (lim << 1) - 1) ret[i] = 0;

    while (k) {
        DFT(ta, lim << 1, 1);
        if (k & 1) {
            DFT(ret, lim << 1, 1);

            rep(i, 0, (lim << 1) - 1) ret[i] = 1ll * ret[i] * ta[i] % mod;

            DFT(ret, lim << 1, -1);

            // rep(i, 0, (lim << 1) - 1) printf("%d ", ret[i]);
            // printf(" (ret)\n");

            // get mod
            rep(i, 0, (lim << 1) - 1) { q[i] = (i < n) ? ret[i] : 0; }
            reverse(q, q + n);
            rep(i, n - m + 1, n - 1) q[i] = 0;
            DFT(q, lim << 1, 1);
            rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tpr[i] % mod;
            DFT(q, lim << 1, -1);
            reverse(q, q + n - m + 1);
            rep(i, n - m + 1, (lim << 1) - 1) q[i] = 0;

            DFT(q, lim << 1, 1);
            rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tp[i] % mod;
            DFT(q, lim << 1, -1);
            rep(i, 0, (lim << 1) - 1) { ret[i] = (ret[i] - q[i] + mod) % mod; }
            /*
            rep(i, 0, (lim << 1) - 1) printf("%d ", ret[i]);
            printf(" (ret mod p)\n");
            */
        }
        rep(i, 0, (lim << 1) - 1) ta[i] = 1ll * ta[i] * ta[i] % mod;
        DFT(ta, lim << 1, -1);

        // rep(i, 0, (lim << 1) - 1) printf("%d ", ta[i]);
        // printf(" (new a)\n");

        // get mod
        rep(i, 0, (lim << 1) - 1) { q[i] = (i < n) ? ta[i] : 0; }
        reverse(q, q + n);
        rep(i, n - m + 1, n - 1) q[i] = 0;
        // rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
        // printf(" (qr)\n");
        DFT(q, lim << 1, 1);
        rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tpr[i] % mod;
        DFT(q, lim << 1, -1);
        // rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
        // printf(" (qr * inv(pr))\n");
        reverse(q, q + n - m + 1);
        rep(i, n - m + 1, (lim << 1) - 1) q[i] = 0;
        // rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
        // printf(" (new a / p)\n");

        DFT(q, lim << 1, 1);
        rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tp[i] % mod;
        DFT(q, lim << 1, -1);
        rep(i, 0, (lim << 1) - 1) { ta[i] = (ta[i] - q[i] + mod) % mod; }

        // rep(i, 0, (lim << 1) - 1) printf("%d ", ta[i]);
        // printf(" (new a mod p)\n");
        k >>= 1;
    }

    vi c;
    rep(i, 0, m - 2) c.push_back(ret[i]);
    return c;
}

// Returns a_n (mod `mod`) for the linear recurrence whose characteristic
// polynomial is g (low degree first, order k = g.size() - 1), given the
// initial terms a[0..k-1]. Computes r(x) = x^n mod g(x) and returns
// sum_i r_i * a_i.
int linear(vi g, vi a, int n) {
    const int order = g.size() - 1;
    const vi rem = pw(vi{0, 1}, n, g);
    long long acc = 0;
    for (int i = 0; i < order; ++i) {
        acc = (acc + 1ll * rem[i] * a[i]) % mod;
    }
    return static_cast<int>(acc);
}

void test_linear() {
    static int f[N];
    int n, k;
    scanf("%d%d", &n, &k);
    rep(i, 1, k) {
        scanf("%d", &f[i]);
        f[i] = (f[i] + mod) % mod;
    }
    vi a;
    int x;
    rep(i, 0, k - 1) {
        scanf("%d", &x);
        a.push_back((x + mod) % mod);
    }
    vi g;
    per(i, k, 1) { g.push_back((mod - f[i]) % mod); }
    g.push_back(1);

    int ret = linear(g, a, n);
    printf("%d", ret);
}

// Globals read by work() and solve(): n is the requested coefficient index,
// t the number of values stored in p[1..t], and m the exponent applied to
// each factor (1 - p[i] x) in solve().
int n, t, m;
int p[N];
// Returns prod_{i=l..r} (1 - p[i] x)^m modulo `mod` (globals p and m) by
// divide and conquer. An empty range (l > r, e.g. t == 0) is the constant 1;
// previously that case recursed forever.
vi solve(int l, int r) {
    if (l > r) return vi(1, 1);
    if (l == r) {
        // Binomial expansion: coefficient j is C(m, j) * (-p[l])^j.
        // -p[l] is reduced into [0, mod) so values outside that range work too.
        const int neg_p = (mod - p[l] % mod) % mod;
        vi g(1, 1);
        int c = 1, pp = 1;
        rep(j, 1, m) {
            pp = 1ll * pp * neg_p % mod;
            c = 1ll * c * (m - j + 1) % mod * inv(j) % mod;
            g.push_back(1ll * c * pp % mod);
        }
        return g;
    }
    int mid = l + (r - l) / 2;
    return solve(l, mid) * solve(mid + 1, r);
}
void work() {
    scanf("%d%d%d", &n, &t, &m);
    rep(i, 1, t) { scanf("%d", &p[i]); }
    vi f = solve(1, t);
    // rep(i, 0, m * t) { printf("%d ", f[i]); }
    vi h(1, 1);
    int c = 1;
    rep(j, 1, m * t) {
        c = 1ll * c * (m * t - j + 1) % mod * inv(mod - j) % mod;
        h.push_back(c);
    }
    h = h * inv(f);
    // rep(i, 0, m * t) { printf("%d ", h[i]); }
    int ret;
    if (n <= m * t) {
        ret = h[n];
    } else {
        reverse(f.begin(), f.end());
        vi a(h.begin() + 1, h.end());
        --n;
        ret = linear(f, a, n);
    }
    printf("%d", ret);
}

// Entry point: solves the linear-recurrence task from stdin. work() handles a
// different task and is not invoked. Falling off the end of main returns 0.
int main() {
    test_linear();
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值