多项式优化常系数齐次线性递推
参考
https://www.cnblogs.com/Troywar/p/9078013.html
https://www.cnblogs.com/cjyyb/p/10152566.html
https://www.cnblogs.com/BAJimH/p/10574975.html
https://blog.csdn.net/jokerwyt/article/details/85345981?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-1.channel_param&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-1.channel_param
线性递推
给
出
长
为
k
的
a
数
列
<
a
0
,
a
1
.
.
.
a
k
−
1
>
和
一
个
无
穷
数
列
f
的
前
k
项
<
f
1
,
f
2
.
.
.
f
k
>
,
求
f
n
。
给出长为k的a数列<a_0,a_1...a_{k-1}>和一个无穷数列f的前k项<f_1,f_2...f_{k}>,求f_n。
给出长为k的a数列<a0,a1...ak−1>和一个无穷数列f的前k项<f1,f2...fk>,求fn。
f
n
=
∑
i
=
1
k
a
i
f
k
−
i
f_n=\sum_{i=1}^ka_if_{k-i}
fn=i=1∑kaifk−i
不同做法的复杂度比较
- 暴 力 O ( n k ) 暴力O(nk) 暴力O(nk)
- 矩 阵 快 速 幂 优 化 O ( k 3 log n ) 矩阵快速幂优化O(k^3\log n) 矩阵快速幂优化O(k3logn)
- 暴 力 多 项 式 快 速 幂 优 化 O ( k 2 log n ) 暴力多项式快速幂优化O(k^2\log n) 暴力多项式快速幂优化O(k2logn)
- 快 速 幂 套 N T T ∣ 多 项 式 取 模 优 化 O ( k log k log n ) 快速幂套NTT|多项式取模优化O(k\log k\log n) 快速幂套NTT∣多项式取模优化O(klogklogn)
求解思路
矩
阵
快
速
幂
求
线
性
地
推
,
从
一
个
初
始
矩
阵
开
始
递
推
,
用
矩
阵
乘
法
,
最
后
在
和
f
相
乘
得
答
案
。
矩阵快速幂求线性地推,从一个初始矩阵开始递推,用矩阵乘法,最后在和f相乘得答案。
矩阵快速幂求线性地推,从一个初始矩阵开始递推,用矩阵乘法,最后在和f相乘得答案。
这
里
主
要
的
复
杂
度
在
于
矩
阵
的
阶
数
k
,
如
果
k
很
大
很
大
,
那
还
不
如
直
接
暴
力
,
所
以
就
有
多
项
式
的
做
法
了
。
这里主要的复杂度在于矩阵的阶数k,如果k很大很大,那还不如直接暴力,所以就有多项式的做法了。
这里主要的复杂度在于矩阵的阶数k,如果k很大很大,那还不如直接暴力,所以就有多项式的做法了。
和 快 速 幂 一 样 , 把 矩 阵 乘 法 换 成 多 项 式 乘 法 , 取 模 换 成 多 项 式 取 模 。 和快速幂一样,把矩阵乘法换成多项式乘法,取模换成多项式取模。 和快速幂一样,把矩阵乘法换成多项式乘法,取模换成多项式取模。
多 项 式 乘 法 可 以 用 N T T 加 速 。 多项式乘法可以用NTT加速。 多项式乘法可以用NTT加速。
多
项
式
取
模
:
多项式取模:
多项式取模:
A
(
x
)
=
B
(
x
)
D
(
x
)
+
R
(
x
)
A(x)=B(x)D(x)+R(x)
A(x)=B(x)D(x)+R(x)
已
知
A
(
x
)
和
B
(
x
)
,
求
商
D
(
x
)
和
余
数
R
(
x
)
。
已知A(x)和B(x),求商D(x)和余数R(x)。
已知A(x)和B(x),求商D(x)和余数R(x)。
步 骤 : 步骤: 步骤:
- 将 多 项 式 系 数 反 转 , 使 得 最 高 次 幂 为 n − m 。 设 反 转 之 后 为 A R ( x ) = B R ( x ) D R ( x ) m o d x n − m + 1 将多项式系数反转,使得最高次幂为n-m。设反转之后为A_R(x)=B_R(x)D_R(x) \;\;mod \;x^{n-m+1} 将多项式系数反转,使得最高次幂为n−m。设反转之后为AR(x)=BR(x)DR(x)modxn−m+1
- D ( x ) = r e v e r s e ( A R ( x ) ∗ B R − 1 ( x ) ) , 即 A 乘 B 的 逆 再 反 转 即 可 。 D(x)=reverse(A_R(x)*B_R^{-1}(x)),即A乘B的逆再反转即可。 D(x)=reverse(AR(x)∗BR−1(x)),即A乘B的逆再反转即可。
- R ( x ) 直 接 用 A ( x ) − B ( x ) D ( x ) 得 到 。 R(x)直接用A(x)-B(x)D(x)得到。 R(x)直接用A(x)−B(x)D(x)得到。
然 后 就 到 为 什 么 可 以 用 多 项 式 处 理 常 系 数 齐 次 线 性 递 推 。 然后就到为什么可以用多项式处理常系数齐次线性递推。 然后就到为什么可以用多项式处理常系数齐次线性递推。
由 于 笔 者 能 力 有 限 , 只 能 看 着 大 佬 们 的 博 客 敲 敲 模 板 , 详 细 解 法 不 再 赘 述 。 由于笔者能力有限,只能看着大佬们的博客敲敲模板,详细解法不再赘述。 由于笔者能力有限,只能看着大佬们的博客敲敲模板,详细解法不再赘述。
整 理 一 下 思 路 : 整理一下思路: 整理一下思路:
已 知 f n , 通 过 以 下 步 骤 得 到 f 2 n : 已知f_n,通过以下步骤得到f_{2n}: 已知fn,通过以下步骤得到f2n:
- 将 表 达 系 数 多 项 式 平 方 , 使 用 F F T 加 速 。 O ( k log k ) 将表达系数多项式平方,使用FFT加速。O(k \log k) 将表达系数多项式平方,使用FFT加速。O(klogk)
- 将 求 得 的 多 项 式 对 特 征 多 项 式 取 模 。 O ( k log k ) 将求得的多项式对特征多项式取模。O ( k \log k ) 将求得的多项式对特征多项式取模。O(klogk)
因 此 , 要 求 得 f n , 从 f 1 倍 增 即 可 , 就 是 上 文 说 的 多 项 式 快 速 幂 。 而 代 码 里 的 一 些 操 作 就 是 黑 科 技 了 。 因此,要求得f_n, 从f_1倍增即可,就是上文说的多项式快速幂。而代码里的一些操作就是黑科技了。 因此,要求得fn,从f1倍增即可,就是上文说的多项式快速幂。而代码里的一些操作就是黑科技了。
笔 者 没 有 用 N T T , 直 接 用 的 任 意 模 数 M T T 。 使 用 方 法 为 : 笔者没有用NTT,直接用的任意模数MTT。使用方法为: 笔者没有用NTT,直接用的任意模数MTT。使用方法为:
inline void MTT(ll *x, ll *y, ll *z, int len)
// 多项式x与y相乘得到z并返回,len为乘法中需要的长度。
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pdd;
#define INF 0x3f3f3f3f
#define lowbit(x) x & (-x)
#define mem(a, b) memset(a , b , sizeof(a))
#define FOR(i, x, n) for(int i = x;i <= n; i++)
const ll mod = 998244353;
// const ll mod = 1e9 + 7;
// const double eps = 1e-6;
const double PI = acos(-1);
// const double R = 0.57721566490153286060651209;
const int N = 3e5 + 10;
struct Complex {
double x, y;
Complex(double a = 0, double b = 0): x(a), y(b) {}
Complex operator + (const Complex &rhs) { return Complex(x + rhs.x, y + rhs.y); }
Complex operator - (const Complex &rhs) { return Complex(x - rhs.x, y - rhs.y); }
Complex operator * (const Complex &rhs) { return Complex(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x); }
Complex conj() { return Complex(x, -y); }
} w[N];
int tr[N];
ll quick_pow(ll a, ll b) {
ll ans = 1;
while(b) {
if(b & 1) ans = ans * a % mod;
a = a * a % mod;
b >>= 1;
}
return ans;
}
int getLen(int n) {
int len = 1; while (len < (n << 1)) len <<= 1;
for (int i = 0; i < len; i++) tr[i] = (tr[i >> 1] >> 1) | (i & 1 ? len >> 1 : 0);
for (int i = 0; i < len; i++) w[i] = w[i] = Complex(cos(2 * PI * i / len), sin(2 * PI * i / len));
return len;
}
void rever(ll *f, int n) { for(int i = 0, j = n - 1;i < j; i++, j--) swap(f[i], f[j]); }
void FFT(Complex *A, int len) {
for (int i = 0; i < len; i++) if(i < tr[i]) swap(A[i], A[tr[i]]);
for (int i = 2, lyc = len >> 1; i <= len; i <<= 1, lyc >>= 1)
for (int j = 0; j < len; j += i) {
Complex *l = A + j, *r = A + j + (i >> 1), *p = w;
for (int k = 0; k < i >> 1; k++) {
Complex tmp = *r * *p;
*r = *l - tmp, *l = *l + tmp;
++l, ++r, p += lyc;
}
}
}
inline void MTT(ll *x, ll *y, ll *z, int len) {
for (int i = 0; i < len; i++) (x[i] += mod) %= mod, (y[i] += mod) %= mod;
static Complex a[N], b[N];
static Complex dfta[N], dftb[N], dftc[N], dftd[N];
for (int i = 0; i < len; i++) a[i] = Complex(x[i] & 32767, x[i] >> 15);
for (int i = 0; i < len; i++) b[i] = Complex(y[i] & 32767, y[i] >> 15);
FFT(a, len), FFT(b, len);
for (int i = 0; i < len; i++) {
int j = (len - i) & (len - 1);
static Complex da, db, dc, dd;
da = (a[i] + a[j].conj()) * Complex(0.5, 0);
db = (a[i] - a[j].conj()) * Complex(0, -0.5);
dc = (b[i] + b[j].conj()) * Complex(0.5, 0);
dd = (b[i] - b[j].conj()) * Complex(0, -0.5);
dfta[j] = da * dc;
dftb[j] = da * dd;
dftc[j] = db * dc;
dftd[j] = db * dd;
}
for (int i = 0; i < len; i++) a[i] = dfta[i] + dftb[i] * Complex(0, 1);
for (int i = 0; i < len; i++) b[i] = dftc[i] + dftd[i] * Complex(0, 1);
FFT(a, len), FFT(b, len);
for (int i = 0; i < len; i++) {
int da = (ll)(a[i].x / len + 0.5) % mod;
int db = (ll)(a[i].y / len + 0.5) % mod;
int dc = (ll)(b[i].x / len + 0.5) % mod;
int dd = (ll)(b[i].y / len + 0.5) % mod;
z[i] = (da + ((ll)(db + dc) << 15) + ((ll)dd << 30)) % mod;
}
}
void Get_Inv(ll *f, ll *g, int n) {
if(n == 1) { g[0] = quick_pow(f[0], mod - 2); return ; }
Get_Inv(f, g, (n + 1) >> 1);
int len = getLen(n);
static ll c[N];
for(int i = 0;i < len; i++) c[i] = i < n ? f[i] : 0;
MTT(c, g, c, len); MTT(c, g, c, len);
for(int i = 0;i < n; i++) g[i] = (2ll * g[i] - c[i] + mod) % mod;
for(int i = n;i < len; i++) g[i] = 0;
for(int i = 0;i < len; i++) c[i] = 0;
}
int len;
int n, k;
ll a[N], h[N];
ll ans[N], s[N];
ll invG[N], G[N];
void Mod(ll *f,ll *g) {
static ll tmp[N];
rever(f, k + k - 1);
for(int i = 0;i < k; i++) tmp[i] = f[i];
MTT(tmp, invG, tmp, len);
for(int i = k - 1; i < len; i++) tmp[i] = 0;
rever(f, k + k - 1); rever(tmp, k - 1);
MTT(tmp, G, tmp, len);
for(int i = 0;i < k; i++) g[i] = (f[i] + mod - tmp[i]) % mod;
for(int i = k;i < len; i++) g[i] = 0;
for(int i = 0;i < len; i++) tmp[i] = 0;
}
void fpow(int b) {
s[1] = 1; ans[0] = 1;
while(b) {
if(b & 1) { MTT(ans, s, ans, len);
Mod(ans, ans); }
MTT(s, s, s, len);
Mod(s, s);
b >>= 1;
}
}
ll DITI(ll *a, ll *h, ll *ans, int n, int k) {
G[k] = 1; for(int i = 1;i <= k; i++) G[k - i] = (mod - a[i]) % mod;
rever(G, k + 1);
len = getLen(k + 1);
Get_Inv(G, invG, k + 1);
for(int i = k + 1;i < len; i++) invG[i] = 0;
rever(G, k + 1);
fpow(n);
ll Ans = 0;
for(int i = 0;i < k; i++) Ans = (Ans + 1ll * h[i] * ans[i] % mod) % mod;
return Ans;
}
void solve()
{
cin >> n >> k;
for(int i = 1;i <= k; i++){ cin >> a[i]; a[i] = a[i] < 0 ? a[i] + mod : a[i]; }
for(int i = 0;i < k; i++) { cin >> h[i]; h[i] = h[i] < 0 ? h[i] + mod : h[i]; }
ll Ans = DITI(a, h, ans, n, k);
cout << Ans << endl;
}
signed main() {
ios_base::sync_with_stdio(false);
//cin.tie(nullptr);
//cout.tie(nullptr);
#ifdef FZT_ACM_LOCAL
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
signed test_index_for_debug = 1;
char acm_local_for_debug = 0;
do {
if (acm_local_for_debug == '$') exit(0);
if (test_index_for_debug > 20)
throw runtime_error("Check the stdin!!!");
auto start_clock_for_debug = clock();
solve();
auto end_clock_for_debug = clock();
cout << "Test " << test_index_for_debug << " successful" << endl;
cerr << "Test " << test_index_for_debug++ << " Run Time: "
<< double(end_clock_for_debug - start_clock_for_debug) / CLOCKS_PER_SEC << "s" << endl;
cout << "--------------------------------------------------" << endl;
} while (cin >> acm_local_for_debug && cin.putback(acm_local_for_debug));
#else
solve();
#endif
return 0;
}
− − − 多 项 式 是 真 的 难 ! ! ! − − ---多项式是真的难!!!-- −−−多项式是真的难!!!−−