题目大意
$$G(N)=\sum_{k_1+k_2+\cdots+k_t=N}F\left(p_1^{k_1}p_2^{k_2}\cdots p_t^{k_t}\right)$$
$$F(n)=\sum_{a_1a_2\cdots a_m=n}\varphi(a_1)\varphi(a_2)\cdots\varphi(a_m)$$
给定 $N,t,m$ 以及 $p_1,p_2,\dots,p_t$,求 $G(N) \bmod 998244353$。
思路
$F(n)$ 是 $m$ 个积性函数 $\varphi$ 的狄利克雷卷积,因此也是积性函数,只需考虑 $F(p^k)$。
$$h_p(x)=\sum_{k=0}^{\infty}\varphi(p^k)x^k=\frac{1-x}{1-px}$$
在素数幂处,狄利克雷卷积对应幂级数的乘法,则 $F(p^k)=[x^k]\left(h_p(x)\right)^m$。
生成函数
$$f_p(x)=\sum_{k=0}^{\infty}F(p^k)x^k=\left(h_p(x)\right)^m=\left(\frac{1-x}{1-px}\right)^m$$
则
$$G(N)=[x^N]\prod_{i=1}^{t}f_{p_i}(x)=[x^N]\prod_{i=1}^{t}\left(\frac{1-x}{1-p_ix}\right)^m=[x^N]\frac{(1-x)^{mt}}{\prod_{i=1}^{t}(1-p_ix)^m}$$
分母 $\prod_{i=1}^{t}(1-p_ix)^m$ 的次数为 $mt$,因此可以先用多项式求逆求出第 $0$ 至第 $mt$ 项,再利用 $mt$ 阶常系数线性递推求出第 $N$ 项。
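由于生成函数是有理分式 $P(x)/Q(x)$,其中 $Q(x)=\prod_{i=1}^{t}(1-p_ix)^m$ 的常数项为 $1$,递推式的系数可以直接由 $Q$ 的系数读出:$\sum_{j=0}^{mt}Q_j g_{n-j}=[x^n]P(x)$。注意分子 $P(x)=(1-x)^{mt}$ 的次数也是 $mt$,所以该齐次递推只在 $n>mt$ 时成立;代码中因此把序列整体平移一位,以第 $1$ 至第 $mt$ 项作为初值。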
之前写过一篇关于常系数齐次线性递推的文章(计算 $x^N$ 对特征多项式取模),利用该做法即可。
代码
#include <bits/stdc++.h>
// Inclusive ascending loop: i runs from l up to r.
// NOTE: i is an int; passing `x.size() - 1` as r for an empty container wraps
// to a huge unsigned bound (int is converted to size_t in the comparison).
#define rep(i, l, r) for (int i = l; i <= r; ++i)
// Inclusive descending loop: i runs from r down to l.
#define per(i, r, l) for (int i = r; i >= l; --i)
using namespace std;
// Capacity (2^20) of every static scratch buffer; NTT lengths must not exceed it.
const int N = 1048576;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; DFT uses 3 as its root.
const int mod = 998244353;
// Polynomial / sequence of residues mod `mod`, lowest-degree coefficient first.
typedef vector<int> vi;
// Modular exponentiation by repeated squaring: returns x^y mod `mod`.
// Intended for residues x; y must be non-negative (a negative y never
// reaches zero under the right shifts).
int pw(int x, int y) {
    long long result = 1;
    long long base = x;
    for (; y != 0; y >>= 1) {
        if (y & 1) {
            result = result * base % mod;
        }
        base = base * base % mod;
    }
    return (int)result;
}
// Modular inverse by Fermat's little theorem (mod is prime): x^(mod - 2).
// x must not be a multiple of mod.
int inv(int x) {
    const int fermat_exponent = mod - 2;
    return pw(x, fermat_exponent);
}
// Bit-reversal permutation for the most recently prepared transform size:
// rev[i] is i with its low log2(n) bits reversed.
int rev[N];
// Prepares rev[0..n-1] for a power-of-two size n. The table is cached, so a
// call with the same n as the previous call does nothing.
void get_rev(int n) {
    static int cached = 1;
    if (cached == n) {
        return;
    }
    cached = n;
    int logn = 0;
    for (; (1 << logn) < n; ++logn) {
    }
    rev[0] = 0;
    for (int i = 1; i < n; ++i) {
        const int top_bit = (i & 1) << (logn - 1);
        rev[i] = (rev[i >> 1] >> 1) | top_bit;
    }
}
void DFT(int *a, int n, int dir) {
get_rev(n);
rep(i, 0, n - 1) {
if (i < rev[i]) swap(a[i], a[rev[i]]);
}
for (int len = 1; (len << 1) <= n; len <<= 1) {
int wn = pw(3, (mod - 1) / (len << 1));
if (dir == -1) wn = inv(wn);
for (int i = 0; i < n; i += len << 1) {
int w = 1;
rep(j, i, i + len - 1) {
int tmp = 1ll * a[j + len] * w % mod;
a[j + len] = (a[j] - tmp + mod) % mod;
a[j] = (a[j] + tmp) % mod;
w = 1ll * w * wn % mod;
}
}
}
if (dir == -1) {
int invn = inv(n);
rep(i, 0, n - 1) { a[i] = 1ll * a[i] * invn % mod; }
}
}
// Coefficient-wise sum of two polynomials mod `mod`; the result is as long as
// the longer operand. size_t loop bounds keep empty operands safe (an int
// loop up to `size() - 1` wraps to SIZE_MAX for an empty vector).
vi operator+(const vi &a, const vi &b) {
    vi ret(max(a.size(), b.size()), 0);
    for (size_t i = 0; i < a.size(); ++i) {
        ret[i] = (ret[i] + a[i]) % mod;
    }
    for (size_t i = 0; i < b.size(); ++i) {
        ret[i] = (ret[i] + b[i]) % mod;
    }
    return ret;
}
// Coefficient-wise difference a - b of two polynomials mod `mod`; the result
// is as long as the longer operand. size_t loop bounds keep empty operands
// safe (an int loop up to `size() - 1` wraps to SIZE_MAX for an empty vector).
vi operator-(const vi &a, const vi &b) {
    vi ret(max(a.size(), b.size()), 0);
    for (size_t i = 0; i < a.size(); ++i) {
        ret[i] = (ret[i] + a[i]) % mod;
    }
    for (size_t i = 0; i < b.size(); ++i) {
        ret[i] = (ret[i] - b[i] + mod) % mod;
    }
    return ret;
}
// Polynomial product via NTT. The result has a.size() + b.size() - 1
// coefficients. Works in static scratch buffers, so the padded transform
// length must not exceed N.
vi operator*(vi a, vi b) {
    static int fa[N], fb[N];
    const int na = a.size();
    const int nb = b.size();
    const int out_len = na + nb - 1;
    int sz = 1;
    while (sz < out_len) {
        sz *= 2;
    }
    for (int i = 0; i < sz; ++i) {
        fa[i] = i < na ? a[i] : 0;
        fb[i] = i < nb ? b[i] : 0;
    }
    DFT(fa, sz, 1);
    DFT(fb, sz, 1);
    for (int i = 0; i < sz; ++i) {
        fa[i] = 1ll * fa[i] * fb[i] % mod;
    }
    DFT(fa, sz, -1);
    return vi(fa, fa + max(out_len, 0));
}
// Power-series inverse: returns b with a * b == 1 mod x^{a.size()}, using
// Newton iteration b <- b * (2 - a * b) that doubles the precision each round.
// Requires a non-empty a with an invertible constant term.
vi inv(vi a) {
    static int fa[N], g[N];
    const int want = a.size();
    int cap = 1;
    while (cap < want) {
        cap *= 2;
    }
    // Clear the working inverse, then seed it with 1 / a[0].
    fill(g, g + 2 * cap, 0);
    g[0] = inv(a[0]);
    for (int prec = 2; prec <= cap; prec *= 2) {
        const int sz = 2 * prec;
        const int take = min(prec, want);
        fill(fa, fa + sz, 0);
        copy(a.begin(), a.begin() + take, fa);
        DFT(fa, sz, 1);
        DFT(g, sz, 1);
        for (int i = 0; i < sz; ++i) {
            const long long fg = 1ll * fa[i] * g[i] % mod;
            g[i] = (2ll - fg + mod) * g[i] % mod;
        }
        DFT(g, sz, -1);
        // Keep only the `prec` coefficients that are now correct.
        fill(g + prec, g + sz, 0);
    }
    return vi(g, g + want);
}
// Polynomial quotient floor(a / b), computed with the reversal trick:
// rev(q) = rev(a) * inv(rev(b)) mod x^{deg a - deg b + 1}.
// Returns {0} when a is shorter than b.
vi operator/(vi a, vi b) {
    const int na = a.size();
    const int nb = b.size();
    if (na < nb) {
        return vi(1, 0);
    }
    const int qlen = na - nb + 1;
    vi ra(a.rbegin(), a.rend());
    vi rb(b.rbegin(), b.rend());
    ra.resize(qlen);
    rb.resize(qlen);
    vi q = ra * inv(rb);
    q.resize(qlen);
    return vi(q.rbegin(), q.rend());
}
// Polynomial remainder a mod b, truncated to b.size() - 1 coefficients.
vi operator%(vi a, vi b) {
    vi rem = a - (a / b) * b;
    rem.resize(b.size() - 1);
    return rem;
}
vi pw(vi a, int k, vi p) {
static int ret[N], ta[N], tp[N], tpr[N], q[N];
int m = p.size();
int n = m * 2 - 3;
int lim = 1;
while (lim < m) lim <<= 1;
// rep(i, 0, m - 1) printf("%d ", p[i]);
// printf(" (p)\n");
vi aa = a % p;
if (m == 2) return vi(1, pw(aa[0], k));
rep(i, 0, (lim << 1) - 1) { ta[i] = (i < m - 1) ? aa[i] : 0; }
// rep(i, 0, m - 2) printf("%d ", ta[i]);
// printf(" (a)\n");
vi pr = p;
reverse(pr.begin(), pr.end());
pr.resize(n - m + 1);
pr = inv(pr);
rep(i, 0, (lim << 1) - 1) {
tp[i] = (i < m) ? p[i] : 0;
tpr[i] = (i < n - m + 1) ? pr[i] : 0;
}
DFT(tp, lim << 1, 1);
DFT(tpr, lim << 1, 1);
ret[0] = 1;
rep(i, 1, (lim << 1) - 1) ret[i] = 0;
while (k) {
DFT(ta, lim << 1, 1);
if (k & 1) {
DFT(ret, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) ret[i] = 1ll * ret[i] * ta[i] % mod;
DFT(ret, lim << 1, -1);
// rep(i, 0, (lim << 1) - 1) printf("%d ", ret[i]);
// printf(" (ret)\n");
// get mod
rep(i, 0, (lim << 1) - 1) { q[i] = (i < n) ? ret[i] : 0; }
reverse(q, q + n);
rep(i, n - m + 1, n - 1) q[i] = 0;
DFT(q, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tpr[i] % mod;
DFT(q, lim << 1, -1);
reverse(q, q + n - m + 1);
rep(i, n - m + 1, (lim << 1) - 1) q[i] = 0;
DFT(q, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tp[i] % mod;
DFT(q, lim << 1, -1);
rep(i, 0, (lim << 1) - 1) { ret[i] = (ret[i] - q[i] + mod) % mod; }
/*
rep(i, 0, (lim << 1) - 1) printf("%d ", ret[i]);
printf(" (ret mod p)\n");
*/
}
rep(i, 0, (lim << 1) - 1) ta[i] = 1ll * ta[i] * ta[i] % mod;
DFT(ta, lim << 1, -1);
// rep(i, 0, (lim << 1) - 1) printf("%d ", ta[i]);
// printf(" (new a)\n");
// get mod
rep(i, 0, (lim << 1) - 1) { q[i] = (i < n) ? ta[i] : 0; }
reverse(q, q + n);
rep(i, n - m + 1, n - 1) q[i] = 0;
// rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
// printf(" (qr)\n");
DFT(q, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tpr[i] % mod;
DFT(q, lim << 1, -1);
// rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
// printf(" (qr * inv(pr))\n");
reverse(q, q + n - m + 1);
rep(i, n - m + 1, (lim << 1) - 1) q[i] = 0;
// rep(i, 0, (lim << 1) - 1) printf("%d ", q[i]);
// printf(" (new a / p)\n");
DFT(q, lim << 1, 1);
rep(i, 0, (lim << 1) - 1) q[i] = 1ll * q[i] * tp[i] % mod;
DFT(q, lim << 1, -1);
rep(i, 0, (lim << 1) - 1) { ta[i] = (ta[i] - q[i] + mod) % mod; }
// rep(i, 0, (lim << 1) - 1) printf("%d ", ta[i]);
// printf(" (new a mod p)\n");
k >>= 1;
}
vi c;
rep(i, 0, m - 2) c.push_back(ret[i]);
return c;
}
// n-th term of a linear recurrence (Fiduccia's method). g is the
// characteristic polynomial (lowest degree first, size order + 1) and a holds
// the initial terms a_0..a_{order-1}. With r = x^n mod g, a_n = sum_i r_i * a_i.
int linear(vi g, vi a, int n) {
    const int order = (int)g.size() - 1;
    const vi rem = pw(vi{0, 1}, n, g);
    long long acc = 0;
    for (int i = 0; i < order; ++i) {
        acc = (acc + 1ll * rem[i] * a[i]) % mod;
    }
    return (int)acc;
}
void test_linear() {
static int f[N];
int n, k;
scanf("%d%d", &n, &k);
rep(i, 1, k) {
scanf("%d", &f[i]);
f[i] = (f[i] + mod) % mod;
}
vi a;
int x;
rep(i, 0, k - 1) {
scanf("%d", &x);
a.push_back((x + mod) % mod);
}
vi g;
per(i, k, 1) { g.push_back((mod - f[i]) % mod); }
g.push_back(1);
int ret = linear(g, a, n);
printf("%d", ret);
}
// Problem parameters read by work(): the target exponent sum (N in the
// statement, stored in n), the number of primes t, and the factor count m.
int n, t, m;
// p[1..t]: the input primes (1-indexed; p[0] is unused).
int p[N];
// Returns the coefficients (lowest degree first, constant term 1) of
// prod_{i=l}^{r} (1 - p_i x)^m mod 998244353, by divide and conquer over
// the index range with NTT multiplication. An empty range (l > r) yields 1.
vi solve(int l, int r) {
    if (l > r) {
        return vi(1, 1);
    }
    if (l == r) {
        // Binomial expansion: (1 - p x)^m = sum_j C(m, j) (-p)^j x^j.
        // Reduce p first: `mod - p` is negative whenever p > mod.
        int res = p[l] % mod;
        if (res < 0) {
            res += mod;
        }
        const int neg_p = (mod - res) % mod;
        vi g(1, 1);
        int c = 1, pp = 1;
        for (int j = 1; j <= m; ++j) {
            pp = 1ll * pp * neg_p % mod;
            c = 1ll * c * (m - j + 1) % mod * inv(j) % mod;  // C(m, j)
            g.push_back(1ll * c * pp % mod);
        }
        return g;
    }
    const int mid = l + (r - l) / 2;
    return solve(l, mid) * solve(mid + 1, r);
}
void work() {
scanf("%d%d%d", &n, &t, &m);
rep(i, 1, t) { scanf("%d", &p[i]); }
vi f = solve(1, t);
// rep(i, 0, m * t) { printf("%d ", f[i]); }
vi h(1, 1);
int c = 1;
rep(j, 1, m * t) {
c = 1ll * c * (m * t - j + 1) % mod * inv(mod - j) % mod;
h.push_back(c);
}
h = h * inv(f);
// rep(i, 0, m * t) { printf("%d ", h[i]); }
int ret;
if (n <= m * t) {
ret = h[n];
} else {
reverse(f.begin(), f.end());
vi a(h.begin() + 1, h.end());
--n;
ret = linear(f, a, n);
}
printf("%d", ret);
}
// Entry point: solve a single test case read from stdin.
int main() {
    work();
    // Falling off the end of main returns 0 in C++.
}