多项式优化常系数齐次线性递推

多项式优化常系数齐次线性递推

参考

https://www.cnblogs.com/Troywar/p/9078013.html
https://www.cnblogs.com/cjyyb/p/10152566.html
https://www.cnblogs.com/BAJimH/p/10574975.html
https://blog.csdn.net/jokerwyt/article/details/85345981?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-1.channel_param&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-1.channel_param

线性递推

给 出 长 为 k 的 a 数 列 < a 0 , a 1 . . . a k − 1 > 和 一 个 无 穷 数 列 f 的 前 k 项 < f 1 , f 2 . . . f k > , 求 f n 。 给出长为k的a数列<a_0,a_1...a_{k-1}>和一个无穷数列f的前k项<f_1,f_2...f_{k}>,求f_n。 ka<a0,a1...ak1>fk<f1,f2...fk>fn
f n = ∑ i = 1 k a i f k − i f_n=\sum_{i=1}^ka_if_{k-i} fn=i=1kaifki

不同做法的复杂度比较

  • 暴 力 O ( n k ) 暴力O(nk) O(nk)
  • 矩 阵 快 速 幂 优 化 O ( k 3 log ⁡ n ) 矩阵快速幂优化O(k^3\log n) O(k3logn)
  • 暴 力 多 项 式 快 速 幂 优 化 O ( k 2 log ⁡ n ) 暴力多项式快速幂优化O(k^2\log n) O(k2logn)
  • 快 速 幂 套 N T T ∣ 多 项 式 取 模 优 化 O ( k log ⁡ k log ⁡ n ) 快速幂套NTT|多项式取模优化O(k\log k\log n) NTTO(klogklogn)

求解思路

矩 阵 快 速 幂 求 线 性 地 推 , 从 一 个 初 始 矩 阵 开 始 递 推 , 用 矩 阵 乘 法 , 最 后 在 和 f 相 乘 得 答 案 。 矩阵快速幂求线性地推,从一个初始矩阵开始递推,用矩阵乘法,最后在和f相乘得答案。 线f
这 里 主 要 的 复 杂 度 在 于 矩 阵 的 阶 数 k , 如 果 k 很 大 很 大 , 那 还 不 如 直 接 暴 力 , 所 以 就 有 多 项 式 的 做 法 了 。 这里主要的复杂度在于矩阵的阶数k,如果k很大很大,那还不如直接暴力,所以就有多项式的做法了。 kk

和 快 速 幂 一 样 , 把 矩 阵 乘 法 换 成 多 项 式 乘 法 , 取 模 换 成 多 项 式 取 模 。 和快速幂一样,把矩阵乘法换成多项式乘法,取模换成多项式取模。

多 项 式 乘 法 可 以 用 N T T 加 速 。 多项式乘法可以用NTT加速。 NTT

多 项 式 取 模 : 多项式取模:
A ( x ) = B ( x ) D ( x ) + R ( x ) A(x)=B(x)D(x)+R(x) A(x)=B(x)D(x)+R(x)
已 知 A ( x ) 和 B ( x ) , 求 商 D ( x ) 和 余 数 R ( x ) 。 已知A(x)和B(x),求商D(x)和余数R(x)。 A(x)B(x)D(x)R(x)

步 骤 : 步骤:

  • 将 多 项 式 系 数 反 转 , 使 得 最 高 次 幂 为 n − m 。 设 反 转 之 后 为 A R ( x ) = B R ( x ) D R ( x )      m o d    x n − m + 1 将多项式系数反转,使得最高次幂为n-m。设反转之后为A_R(x)=B_R(x)D_R(x) \;\;mod \;x^{n-m+1} 使nmAR(x)=BR(x)DR(x)modxnm+1
  • D ( x ) = r e v e r s e ( A R ( x ) ∗ B R − 1 ( x ) ) , 即 A 乘 B 的 逆 再 反 转 即 可 。 D(x)=reverse(A_R(x)*B_R^{-1}(x)),即A乘B的逆再反转即可。 D(x)=reverse(AR(x)BR1(x))AB
  • R ( x ) 直 接 用 A ( x ) − B ( x ) D ( x ) 得 到 。 R(x)直接用A(x)-B(x)D(x)得到。 R(x)A(x)B(x)D(x)

然 后 就 到 为 什 么 可 以 用 多 项 式 处 理 常 系 数 齐 次 线 性 递 推 。 然后就到为什么可以用多项式处理常系数齐次线性递推。 线

由 于 笔 者 能 力 有 限 , 只 能 看 着 大 佬 们 的 博 客 敲 敲 模 板 , 详 细 解 法 不 再 赘 述 。 由于笔者能力有限,只能看着大佬们的博客敲敲模板,详细解法不再赘述。

整 理 一 下 思 路 : 整理一下思路:

已 知 f n , 通 过 以 下 步 骤 得 到 f 2 n : 已知f_n,通过以下步骤得到f_{2n}: fnf2n

  • 将 表 达 系 数 多 项 式 平 方 , 使 用 F F T 加 速 。 O ( k log ⁡ k ) 将表达系数多项式平方,使用FFT加速。O(k \log k) 使FFTO(klogk)
  • 将 求 得 的 多 项 式 对 特 征 多 项 式 取 模 。 O ( k log ⁡ ⁡ k ) 将求得的多项式对特征多项式取模。O ( k \log ⁡ k ) O(klogk)

因 此 , 要 求 得 f n , 从 f 1 倍 增 即 可 , 就 是 上 文 说 的 多 项 式 快 速 幂 。 而 代 码 里 的 一 些 操 作 就 是 黑 科 技 了 。 因此,要求得f_n, 从f_1倍增即可,就是上文说的多项式快速幂。而代码里的一些操作就是黑科技了。 fn,f1

笔 者 没 有 用 N T T , 直 接 用 的 任 意 模 数 M T T 。 使 用 方 法 为 : 笔者没有用NTT,直接用的任意模数MTT。使用方法为: NTTMTT使

inline void MTT(ll *x, ll *y, ll *z, int len)
// 多项式x与y相乘得到z并返回,len为乘法中需要的长度。

Code

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef long double ld;
typedef pair<int, int> pdd;

#define INF 0x3f3f3f3f
#define lowbit(x) x & (-x)
#define mem(a, b) memset(a , b , sizeof(a))
#define FOR(i, x, n) for(int i = x;i <= n; i++)

 const ll mod = 998244353;
// const ll mod = 1e9 + 7;
// const double eps = 1e-6;
 const double PI = acos(-1);
// const double R = 0.57721566490153286060651209;

const int N = 3e5 + 10;

struct Complex {
    double x, y;
    Complex(double a = 0, double b = 0): x(a), y(b) {}
    Complex operator + (const Complex &rhs) { return Complex(x + rhs.x, y + rhs.y); }
    Complex operator - (const Complex &rhs) { return Complex(x - rhs.x, y - rhs.y); }
    Complex operator * (const Complex &rhs) { return Complex(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x); }
    Complex conj() { return Complex(x, -y); }
} w[N];

int tr[N];

ll quick_pow(ll a, ll b) {
    ll ans = 1;
    while(b) {
        if(b & 1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}

int getLen(int n) {
    int len = 1; while (len < (n << 1)) len <<= 1;
    for (int i = 0; i < len; i++) tr[i] = (tr[i >> 1] >> 1) | (i & 1 ? len >> 1 : 0);
    for (int i = 0; i < len; i++) w[i] = w[i] = Complex(cos(2 * PI * i / len), sin(2 * PI * i / len));
    return len;
}

void rever(ll *f, int n) { for(int i = 0, j = n - 1;i < j; i++, j--) swap(f[i], f[j]); }

void FFT(Complex *A, int len) {
    for (int i = 0; i < len; i++) if(i < tr[i]) swap(A[i], A[tr[i]]);
    for (int i = 2, lyc = len >> 1; i <= len; i <<= 1, lyc >>= 1)
        for (int j = 0; j < len; j += i) {
            Complex *l = A + j, *r = A + j + (i >> 1), *p = w;
            for (int k = 0; k < i >> 1; k++) {
                Complex tmp = *r * *p;
                *r = *l - tmp, *l = *l + tmp;
                ++l, ++r, p += lyc;
            }
        }
}

inline void MTT(ll *x, ll *y, ll *z, int len) {

    for (int i = 0; i < len; i++) (x[i] += mod) %= mod, (y[i] += mod) %= mod;
    static Complex a[N], b[N];
    static Complex dfta[N], dftb[N], dftc[N], dftd[N];

    for (int i = 0; i < len; i++) a[i] = Complex(x[i] & 32767, x[i] >> 15);
    for (int i = 0; i < len; i++) b[i] = Complex(y[i] & 32767, y[i] >> 15);
    FFT(a, len), FFT(b, len);
    for (int i = 0; i < len; i++) {
        int j = (len - i) & (len - 1);
        static Complex da, db, dc, dd;
        da = (a[i] + a[j].conj()) * Complex(0.5, 0);
        db = (a[i] - a[j].conj()) * Complex(0, -0.5);
        dc = (b[i] + b[j].conj()) * Complex(0.5, 0);
        dd = (b[i] - b[j].conj()) * Complex(0, -0.5);
        dfta[j] = da * dc;
        dftb[j] = da * dd;
        dftc[j] = db * dc;
        dftd[j] = db * dd;
    }
    for (int i = 0; i < len; i++) a[i] = dfta[i] + dftb[i] * Complex(0, 1);
    for (int i = 0; i < len; i++) b[i] = dftc[i] + dftd[i] * Complex(0, 1);
    FFT(a, len), FFT(b, len);
    for (int i = 0; i < len; i++) {
        int da = (ll)(a[i].x / len + 0.5) % mod;
        int db = (ll)(a[i].y / len + 0.5) % mod;
        int dc = (ll)(b[i].x / len + 0.5) % mod;
        int dd = (ll)(b[i].y / len + 0.5) % mod;
        z[i] = (da + ((ll)(db + dc) << 15) + ((ll)dd << 30)) % mod;
    }
}

void Get_Inv(ll *f, ll *g, int n) {
    if(n == 1) { g[0] = quick_pow(f[0], mod - 2); return ; }
    Get_Inv(f, g, (n + 1) >> 1);

    int len = getLen(n);
    static ll c[N];
    for(int i = 0;i < len; i++) c[i] = i < n ? f[i] : 0;
    MTT(c, g, c, len); MTT(c, g, c, len);
    for(int i = 0;i < n; i++) g[i] = (2ll * g[i] - c[i] + mod) % mod;
    for(int i = n;i < len; i++) g[i] = 0;
    for(int i = 0;i < len; i++) c[i] = 0;
}

int len;
int n, k;
ll a[N], h[N];
ll ans[N], s[N];
ll invG[N], G[N];

void Mod(ll *f,ll *g) {
    static ll tmp[N];
    rever(f, k + k - 1);
    for(int i = 0;i < k; i++) tmp[i] = f[i];
    MTT(tmp, invG, tmp, len);
    for(int i = k - 1; i < len; i++) tmp[i] = 0;
    rever(f, k + k - 1); rever(tmp, k - 1);
    MTT(tmp, G, tmp, len);
    for(int i = 0;i < k; i++) g[i] = (f[i] + mod - tmp[i]) % mod;
    for(int i = k;i < len; i++) g[i] = 0;
    for(int i = 0;i < len; i++) tmp[i] = 0;
}

void fpow(int b) {
    s[1] = 1; ans[0] = 1;
    while(b) {
        if(b & 1) { MTT(ans, s, ans, len);
        Mod(ans, ans); }
        MTT(s, s, s, len);
        Mod(s, s);
        b >>= 1;
    }
}

ll DITI(ll *a, ll *h, ll *ans, int n, int k) {
    G[k] = 1; for(int i = 1;i <= k; i++) G[k - i] = (mod - a[i]) % mod;
    rever(G, k + 1);
    len = getLen(k + 1);
    Get_Inv(G, invG, k + 1);
    for(int i = k + 1;i < len; i++) invG[i] = 0;
    rever(G, k + 1);
    fpow(n);
    ll Ans = 0;
    for(int i = 0;i < k; i++) Ans = (Ans + 1ll * h[i] * ans[i] % mod) % mod;
    return Ans;
}

void solve()
{
    cin >> n >> k;
    for(int i = 1;i <= k; i++){ cin >> a[i]; a[i] = a[i] < 0 ? a[i] + mod : a[i]; }
    for(int i = 0;i < k; i++) { cin >> h[i]; h[i] = h[i] < 0 ? h[i] + mod : h[i]; }

    ll Ans = DITI(a, h, ans, n, k);
    cout << Ans << endl;
}

signed main() {
    ios_base::sync_with_stdio(false);
    //cin.tie(nullptr);
    //cout.tie(nullptr);
#ifdef FZT_ACM_LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    signed test_index_for_debug = 1;
    char acm_local_for_debug = 0;
    do {
        if (acm_local_for_debug == '$') exit(0);
        if (test_index_for_debug > 20)
            throw runtime_error("Check the stdin!!!");
        auto start_clock_for_debug = clock();
        solve();
        auto end_clock_for_debug = clock();
        cout << "Test " << test_index_for_debug << " successful" << endl;
        cerr << "Test " << test_index_for_debug++ << " Run Time: "
             << double(end_clock_for_debug - start_clock_for_debug) / CLOCKS_PER_SEC << "s" << endl;
        cout << "--------------------------------------------------" << endl;
    } while (cin >> acm_local_for_debug && cin.putback(acm_local_for_debug));
#else
    solve();
#endif
    return 0;
}

− − − 多 项 式 是 真 的 难 ! ! ! − − ---多项式是真的难!!!--

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值