题目描述
数列 $\{a_n\}$ 满足 $k$ 阶线性递推关系:$a_n=\sum_{i=1}^k f_ia_{n-i}\ (n\ge k)$。
现给定 $n,k$ 以及 $f_1,f_2,\dots,f_k,a_0,a_1,\dots,a_{k-1}$,求 $a_n \bmod 998244353$ 的值。
数据范围
$n\le 10^9,\ k\le 32000$
思路
令 $k\times k$ 的友矩阵 $F=\begin{bmatrix}0&1&\cdots&0\\0&0&\ddots&0\\0&0&\cdots&1\\f_{k}&f_{k-1}&\cdots&f_1\end{bmatrix}$。
由于 $F$ 把 $(a_i,a_{i+1},\dots,a_{i+k-1})^T$ 变为 $(a_{i+1},a_{i+2},\dots,a_{i+k})^T$(最后一行正是递推式),反复应用 $n$ 次得 $\begin{bmatrix}a_n\\a_{n+1}\\\vdots\\a_{n+k-1}\end{bmatrix}=F^n\begin{bmatrix}a_0\\a_1\\\vdots\\a_{k-1}\end{bmatrix}$。
$F$ 的特征多项式为 $g(\lambda)=|\lambda I-F|=\lambda^k-\sum_{i=1}^kf_i\lambda^{k-i}$。
根据代数基本定理,可以在复数域上把 $g(\lambda)$ 表示为 $g(\lambda)=\prod_{i=1}^t(\lambda-\lambda_i)^{m_i}$,其中 $\lambda_i$ 各不相同,且 $\sum_{i=1}^t m_i=k$。
因此 $a_n=\sum_{i=1}^t\lambda_i^n\left(\sum_{j=1}^{m_i}c_{ij}n^{j-1}\right)$,其中 $c_{ij}$ 可以通过待定系数法确定。
但是 $\lambda_i$ 咋解?$c_{ij}$ 咋求?好像有点麻烦。
这时,还需要用到 Hamilton-Cayley 定理:$g(F)=F^k-\sum_{i=1}^kf_iF^{k-i}=O$。
若 $r(\lambda)=\lambda^n \bmod g(\lambda)=\sum_{i=0}^{k-1}r_i\lambda^i$,
则由带余除法 $\lambda^n=q(\lambda)g(\lambda)+r(\lambda)$ 及 $g(F)=O$ 得 $F^n=q(F)g(F)+r(F)=r(F)$,于是 $\begin{bmatrix}a_n\\a_{n+1}\\\vdots\\a_{n+k-1}\end{bmatrix}=r(F)\begin{bmatrix}a_0\\a_1\\\vdots\\a_{k-1}\end{bmatrix}$,
取等式两边的第一个分量,并注意到 $F^i(a_0,\dots,a_{k-1})^T$ 的第一个分量就是 $a_i$,得 $a_n=\sum_{i=0}^{k-1}r_ia_i$。
总结一下:
- 求出递推关系的特征多项式 $g(\lambda)=\lambda^k-\sum_{i=1}^kf_i\lambda^{k-i}$。
- 求出 $r(\lambda)=\lambda^n\bmod g(\lambda)=\sum_{i=0}^{k-1}r_i\lambda^i$。这一步需要多项式快速幂+取模(基于 NTT,复杂度 $O(k\log k\log n)$)。
- 求出 $a_n=\sum_{i=0}^{k-1}r_ia_i$。
代码
#include <bits/stdc++.h>
// rep / per: inclusive ascending / descending int loops over [l, r].
// NOTE(review): the bound is not parenthesized or cast; if r is a size_t
// expression such as v.size() - 1 and v is empty, it wraps to SIZE_MAX.
#define rep(i, l, r) for (int i = l; i <= r; ++i)
#define per(i, r, l) for (int i = r; i >= l; --i)
using namespace std;
// Capacity (2^20) of the static scratch buffers; transform lengths must not exceed it.
const int N = 1048576;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; DFT below uses 3 as the primitive root.
const int mod = 998244353;
// Polynomials are coefficient vectors, lowest degree first.
typedef vector<int> vi;
// Modular exponentiation x^y mod `mod` by square-and-multiply.
// x is presumably already reduced into [0, mod); y == 0 yields 1.
int pw(int x, int y) {
  long long base = x;
  long long acc = 1;
  for (; y != 0; y >>= 1) {
    if (y & 1) acc = acc * base % mod;
    base = base * base % mod;
  }
  return (int)acc;
}
// Modular inverse by Fermat's little theorem (mod is prime): x^(mod-2).
// For x == 0 this returns 0, since zero has no inverse.
int inv(int x) {
  const int fermat_exponent = mod - 2;
  return pw(x, fermat_exponent);
}
int rev[N];
// Fills rev[0..n-1] with the bit-reversal permutation for a transform of
// power-of-two length n. The most recent length is remembered in a
// function-local static, so consecutive calls with the same n do nothing.
void get_rev(int n) {
  static int cached_len = 1;
  if (cached_len == n) return;
  cached_len = n;
  int log2n = 0;
  for (; (1 << log2n) < n; ++log2n) {
  }
  rev[0] = 0;
  for (int i = 1; i < n; ++i) {
    int low_bit = i & 1;
    rev[i] = (rev[i >> 1] >> 1) | (low_bit << (log2n - 1));
  }
}
void DFT(int *a, int n, int dir) {
get_rev(n);
rep(i, 0, n - 1) {
if (i < rev[i]) swap(a[i], a[rev[i]]);
}
for (int len = 1; (len << 1) <= n; len <<= 1) {
int wn = pw(3, (mod - 1) / (len << 1));
if (dir == -1) wn = inv(wn);
for (int i = 0; i < n; i += len << 1) {
int w = 1;
rep(j, i, i + len - 1) {
int tmp = 1ll * a[j + len] * w % mod;
a[j + len] = (a[j] - tmp + mod) % mod;
a[j] = (a[j] + tmp) % mod;
w = 1ll * w * wn % mod;
}
}
}
if (dir == -1) {
int invn = inv(n);
rep(i, 0, n - 1) { a[i] = 1ll * a[i] * invn % mod; }
}
}
// Coefficient-wise sum of two polynomials with coefficients in [0, mod).
// The result has max(|a|, |b|) coefficients. Either operand may be empty:
// the loops use size_t indices, so an empty vector no longer makes the
// bound wrap to SIZE_MAX and run off the end.
vi operator+(const vi &a, const vi &b) {
  vi ret(max(a.size(), b.size()), 0);
  for (size_t i = 0; i < a.size(); ++i) ret[i] = (ret[i] + a[i]) % mod;
  for (size_t i = 0; i < b.size(); ++i) ret[i] = (ret[i] + b[i]) % mod;
  return ret;
}
// Coefficient-wise difference a - b of polynomials with coefficients in
// [0, mod). The result has max(|a|, |b|) coefficients. Either operand may be
// empty: size_t indices avoid the wrapped `size() - 1` bound of the original.
vi operator-(const vi &a, const vi &b) {
  vi ret(max(a.size(), b.size()), 0);
  for (size_t i = 0; i < a.size(); ++i) ret[i] = (ret[i] + a[i]) % mod;
  // ret[i] - b[i] lies in (-mod, mod), so + mod keeps it non-negative.
  for (size_t i = 0; i < b.size(); ++i) ret[i] = (ret[i] - b[i] + mod) % mod;
  return ret;
}
// Polynomial product via NTT. Returns |a| + |b| - 1 coefficients.
// Both operands are copied into static scratch buffers zero-padded to the
// next power of two >= |a| + |b| - 1, which must not exceed N.
vi operator*(vi a, vi b) {
  static int ta[N], tb[N];
  const int out_len = a.size() + b.size() - 1;
  const int na = a.size();
  const int nb = b.size();
  int sz = 1;
  while (sz < out_len) sz *= 2;
  for (int i = 0; i < sz; ++i) {
    ta[i] = i < na ? a[i] : 0;
    tb[i] = i < nb ? b[i] : 0;
  }
  DFT(ta, sz, 1);
  DFT(tb, sz, 1);
  for (int i = 0; i < sz; ++i) ta[i] = 1ll * ta[i] * tb[i] % mod;
  DFT(ta, sz, -1);
  vi out;
  for (int i = 0; i < out_len; ++i) out.push_back(ta[i]);
  return out;
}
// Power-series inverse: returns b with a * b == 1 (mod x^{|a|}).
// Uses Newton iteration b <- b * (2 - a * b), doubling the precision `len`
// each round. Requires a[0] != 0 (mod `mod`); otherwise inv(a[0]) is 0 and
// the result is meaningless. An empty input yields an empty result instead
// of reading a[0].
vi inv(vi a) {
  static int ta[N], tb[N];
  int n = a.size();
  if (n == 0) return vi();
  int lim = 1;
  while (lim < n) lim <<= 1;
  tb[0] = inv(a[0]);
  rep(i, 1, (lim << 1) - 1) tb[i] = 0;
  for (int len = 2; len <= lim; len <<= 1) {
    // Only the first len coefficients of a matter at this precision.
    rep(i, 0, (len << 1) - 1) { ta[i] = (i < len && i < n) ? a[i] : 0; }
    DFT(ta, len << 1, 1);
    DFT(tb, len << 1, 1);
    // Transform length 2*len covers deg(a * b * b) <= 2*len - 3, so the
    // pointwise update does not wrap around.
    rep(i, 0, (len << 1) - 1) {
      tb[i] = (2ll - 1ll * ta[i] * tb[i] % mod + mod) * tb[i] % mod;
    }
    DFT(tb, len << 1, -1);
    rep(i, len, (len << 1) - 1) tb[i] = 0;  // truncate to mod x^len
  }
  return vi(tb, tb + n);
}
// Polynomial quotient floor(a / b), computed with the reversal identity
//   rev(q) = rev(a) * rev(b)^{-1}  (mod x^{|a| - |b| + 1}).
// When a has fewer coefficients than b, the quotient is the zero polynomial {0}.
vi operator/(vi a, vi b) {
  const int na = a.size();
  const int nb = b.size();
  if (na < nb) return vi(1, 0);
  const int qlen = na - nb + 1;
  vi ra(a.rbegin(), a.rend());
  vi rb(b.rbegin(), b.rend());
  ra.resize(qlen);
  rb.resize(qlen);
  vi q_rev = ra * inv(rb);
  q_rev.resize(qlen);
  return vi(q_rev.rbegin(), q_rev.rend());
}
// Polynomial remainder a mod b, returned with exactly |b| - 1 coefficients
// (zero-padded or truncated as needed).
vi operator%(vi a, vi b) {
  vi quotient = a / b;
  vi remainder = a - quotient * b;
  remainder.resize(b.size() - 1);
  return remainder;
}
vi pw(vi a, int k, vi p) {
static int ret[N], ta[N], tp[N], tpr[N], q[N];
int m = p.size();
int n = m * 2 - 3;
int lim = 1;
while (lim < m) lim <<= 1;
// rep(i, 0, m - 1) printf("%d ", p[i]);
// printf(" (p)\n");
vi aa = a % p;
if (m == 2) return vi(1, pw(aa[0], k));
rep(i, 0, (lim << 1) - 1) { ta[i] = (i < m - 1) ? aa[i] : 0; }
// rep(i, 0, m - 2) printf("%d ", ta[i]);
// printf(" (a)\n");
vi pr = p;
reverse(pr.begin(), pr.end());
pr.resize(n - m + 1);
pr = inv(pr);
rep(i, 0, (lim << 1) - 1) {
tp[i] = (i < m) ? p[i] : 0;
tpr[i] = (i < n - m + 1) ? pr[i] : 0;
}
DFT(tp, lim << 1, 1);
DFT(tpr, lim << 1, 1);
ret[0] = 1;
rep(i, 1, (lim << 1) - 1) ret[i] = 0;
while (k) {
DFT(ta, lim << 1, 1);
if (k & 1) {
DFT(ret, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) ret[i] = 1ll * ret[i] * ta[i] % mod;
DFT(ret, lim << 1, -1);
// rep(i, 0, (lim << 1) - 1) printf("%d ", ret[i]);
// printf(" (ret)\n");
// get mod
rep(i, 0, (lim << 1) - 1) { q[i] = (i < n) ? ret[i] : 0; }
reverse(q, q + n);
rep(i, n - m + 1, n - 1) q[i] = 0;
DFT(q, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tpr[i] % mod;
DFT(q, lim << 1, -1);
reverse(q, q + n - m + 1);
rep(i, n - m + 1, (lim << 1) - 1) q[i] = 0;
DFT(q, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tp[i] % mod;
DFT(q, lim << 1, -1);
rep(i, 0, (lim << 1) - 1) { ret[i] = (ret[i] - q[i] + mod) % mod; }
/*
rep(i, 0, (lim << 1) - 1) printf("%d ", ret[i]);
printf(" (ret mod p)\n");
*/
}
rep(i, 0, (lim << 1) - 1) ta[i] = 1ll * ta[i] * ta[i] % mod;
DFT(ta, lim << 1, -1);
// rep(i, 0, (lim << 1) - 1) printf("%d ", ta[i]);
// printf(" (new a)\n");
// get mod
rep(i, 0, (lim << 1) - 1) { q[i] = (i < n) ? ta[i] : 0; }
reverse(q, q + n);
rep(i, n - m + 1, n - 1) q[i] = 0;
// rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
// printf(" (qr)\n");
DFT(q, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tpr[i] % mod;
DFT(q, lim << 1, -1);
// rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
// printf(" (qr * inv(pr))\n");
reverse(q, q + n - m + 1);
rep(i, n - m + 1, (lim << 1) - 1) q[i] = 0;
// rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
// printf(" (new a / p)\n");
DFT(q, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tp[i] % mod;
DFT(q, lim << 1, -1);
rep(i, 0, (lim << 1) - 1) { ta[i] = (ta[i] - q[i] + mod) % mod; }
// rep(i, 0, (lim << 1) - 1) printf("%d ", ta[i]);
// printf(" (new a mod p)\n");
k >>= 1;
}
vi c;
rep(i, 0, m - 2) c.push_back(ret[i]);
return c;
}
// n-th term of the linear recurrence whose characteristic polynomial is g
// (lowest degree first, degree k, presumably monic as built by callers).
// a[0..k-1] are the initial terms. By Cayley-Hamilton,
//   a_n = sum_i [x^i](x^n mod g) * a_i.
int linear(vi g, vi a, int n) {
  const int deg = g.size() - 1;
  vi x_poly{0, 1};
  vi rem = pw(x_poly, n, g);
  long long acc = 0;
  for (int i = 0; i < deg; ++i) acc = (acc + 1ll * rem[i] * a[i]) % mod;
  return (int)acc;
}
void test_linear() {
static int f[N];
int n, k;
scanf("%d%d", &n, &k);
rep(i, 1, k) {
scanf("%d", &f[i]);
f[i] = (f[i] + mod) % mod;
}
vi a;
int x;
rep(i, 0, k - 1) {
scanf("%d", &x);
a.push_back((x + mod) % mod);
}
vi g;
per(i, k, 1) { g.push_back((mod - f[i]) % mod); }
g.push_back(1);
int ret = linear(g, a, n);
printf("%d", ret);
}
int n, t, m;
int p[N];
// Returns the coefficients (lowest degree first) of
//   prod_{i=l}^{r} (1 - p[i] x)^m   (mod 998244353),
// computed by divide and conquer over [l, r].
// An empty range (l > r, e.g. t == 0) yields the constant 1; previously it
// recursed forever.
vi solve(int l, int r) {
  if (l > r) return vi(1, 1);
  if (l == r) {
    // Binomial expansion: [x^j] (1 - p x)^m = C(m, j) * (-p)^j.
    vi g(1, 1);
    // -p[l] reduced into [0, mod). This is also safe for negative or
    // >= mod inputs, where mod - p[l] could overflow or go negative.
    int neg_p = (mod - p[l] % mod) % mod;
    int c = 1, pp = 1;
    rep(j, 1, m) {
      pp = 1ll * pp * neg_p % mod;
      c = 1ll * c * (m - j + 1) % mod * inv(j) % mod;
      g.push_back(1ll * c * pp % mod);
    }
    return g;
  }
  int mid = l + (r - l) / 2;
  return solve(l, mid) * solve(mid + 1, r);
}
void work() {
scanf("%d%d%d", &n, &t, &m);
rep(i, 1, t) { scanf("%d", &p[i]); }
vi f = solve(1, t);
// rep(i, 0, m * t) { printf("%d ", f[i]); }
vi h(1, 1);
int c = 1;
rep(j, 1, m * t) {
c = 1ll * c * (m * t - j + 1) % mod * inv(mod - j) % mod;
h.push_back(c);
}
h = h * inv(f);
// rep(i, 0, m * t) { printf("%d ", h[i]); }
int ret;
if (n <= m * t) {
ret = h[n];
} else {
reverse(f.begin(), f.end());
vi a(h.begin() + 1, h.end());
--n;
ret = linear(f, a, n);
}
printf("%d", ret);
}
// Entry point: runs only the linear-recurrence driver; work() is not wired up.
int main() {
  test_linear();
}