题面
解题思路
该题的题目描述有问题;根据其题解可以推断出,该 $k$ 维物体每一维的大小都是相同的。
令 $dp[i][j]$ 表示已经套了 $i$ 层,且第 $i$ 层的宽度是 $j$。
有如下转移:$dp[i][j]=\sum_{w=j}^{n} dp[i-1][w]\cdot (w-j+1)^k$,其中第一层为 $dp[1][j]=(n-j+1)^k$。逐层直接转移的复杂度为 $O(n^2 d)$。
考虑将转移系数写成矩阵形式,并用矩阵快速幂计算,复杂度为 $O(n^3\log_2 d)$。
我们发现系数矩阵 $A$ 是一个上三角矩阵,且有如下性质:$A[i][j]=A[i-1][j-1]$。
因此只需要维护矩阵的第一行,就能够还原出整个矩阵;两个这样的矩阵相乘,等价于把各自的第一行看作多项式相乘并只保留前 $n$ 项,所以可以用卷积(快速傅里叶变换)来加速。
由于题目的模数是任意的,卷积需要使用 MTT 或者三模数 NTT,最后再用 CRT(中国剩余定理)合并结果。
使用CRT时应当预处理逆元来保证不会TLE。
总复杂度为 $O(n\log_2 n\log_2 d)$。
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
// 128-bit signed integer (GCC/Clang extension); products of two 64-bit
// residues fit without overflow.
typedef __int128 Int;
// Capacity used for the fixed-size coefficient arrays in this file.
const int N = 2e5 + 100;

/**
 * Modular exponentiation by repeated squaring.
 *
 * @param x    base; reduced into [0, MOD) before use, so negative or large
 *             values are accepted
 * @param n    exponent; values <= 0 are treated as 0
 * @param MOD  modulus, must be positive and below 2^63 so that the product of
 *             two residues fits in Int
 * @return x^n mod MOD in [0, MOD); in particular 0 when MOD == 1
 */
Int qpow(Int x, Int n, Int MOD) {
    x %= MOD;
    if (x < 0) x += MOD;
    // The accumulator is as wide as the return type so results are not
    // truncated when MOD exceeds the int range.
    Int res = 1 % MOD;
    while (n > 0) {
        if (n & 1) res = res * x % MOD;
        x = x * x % MOD;
        n /= 2;
    }
    return res;
}
struct Poly {
int MOD, G;
int a[N], b[N], c[N], rev[N];
int add(int a, int b) { return a + b >= MOD ? a + b - MOD : a + b; }
int mul(int a, int b) { return 1LL * a * b % MOD; }
Poly(int MOD, int G) :MOD(MOD), G(G) {};
void ntt(int *a, int n, int op) {
for (int i = 0; i < n; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int i = 1; i < n; i <<= 1) {
int gn = qpow(G, (MOD - 1) / (i << 1), MOD);
for (int j = 0; j < n; j += (i << 1)) {
for (int k = 0, g = 1; k < i; k++, g = mul(g, gn)) {
int x = a[j + k], y = mul(g, a[i + j + k]);
a[j + k] = add(x, y);
a[i + j + k] = add(x, MOD - y);
}
}
}
if (op == 1) return;
reverse(a + 1, a + n);
int inv = qpow(n, MOD - 2, MOD);
for (int i = 0; i < n; i++) a[i] = mul(a[i], inv);
}
void getMul(int *aa, int n, int *bb, int m, int *cc, int q) {
int p = 1, l = 0;
while (p < n + m) p <<= 1, l++;
for (int i = 0; i < p; i++) a[i] = b[i] = c[i] = rev[i] = 0;
for (int i = 0; i < p; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
for (int i = 0; i < n; i++) a[i] = aa[i];
for (int i = 0; i < m; i++) b[i] = bb[i];
ntt(a, p, 1); ntt(b, p, 1);
for (int i = 0; i < p; i++) c[i] = mul(a[i], b[i]);
ntt(c, p, -1);
for (int i = 0; i < q; i++) cc[i] = c[i];
}
}f1(998244353, 3), f2(7340033, 3), f3(104857601, 3);
// The three NTT primes, matching the moduli of f1, f2 and f3.
const Int m1 = 998244353;
const Int m2 = 7340033;
const Int m3 = 104857601;

// Input parameters read by main: exponent k, length n, power d, output modulus x.
int k;
int n;
int d;
int x;

// co: coefficient row, res: accumulated result, r1..r3: per-prime products.
int co[N];
int res[N];
int r1[N];
int r2[N];
int r3[N];

// Precomputed CRT factors: cXY is a multiple of mY that is congruent to 1
// modulo mX (mY times the inverse of mY modulo mX).
Int c12 = m2 * qpow(m2, m1 - 2, m1);
Int c13 = m3 * qpow(m3, m1 - 2, m1);
Int c21 = m1 * qpow(m1, m2 - 2, m2);
Int c23 = m3 * qpow(m3, m2 - 2, m2);
Int c31 = m1 * qpow(m1, m3 - 2, m3);
Int c32 = m2 * qpow(m2, m3 - 2, m3);
/**
 * Combines three residues into the unique value modulo m1 * m2 * m3.
 *
 * @param ans1 residue modulo m1
 * @param ans2 residue modulo m2
 * @param ans3 residue modulo m3
 * @return the combined value in [0, m1 * m2 * m3) for non-negative residues
 */
Int CRT(Int ans1, Int ans2, Int ans3) {
    const Int modulus = m1 * m2 * m3;
    // Each weight is 1 modulo its own prime and 0 modulo the other two.
    const Int w1 = c12 * c13 % modulus;
    const Int w2 = c21 * c23 % modulus;
    const Int w3 = c31 * c32 % modulus;
    Int total = w1 * ans1;
    total += w2 * ans2;
    total += w3 * ans3;
    return total % modulus;
}
void solve(int *a, int *b) {
f1.getMul(a, n, b, n, r1, n);
f2.getMul(a, n, b, n, r2, n);
f3.getMul(a, n, b, n, r3, n);
for (int i = 0; i < n; i++) a[i] = CRT(r1[i], r2[i], r3[i]) % x;
}
int main() {
//freopen("0.txt", "r", stdin);
scanf("%d%d%d%d", &k, &n, &d, &x);
for (int i = 0; i < n; i++) co[i] = qpow(i + 1, k, x);
res[0] = 1;
while (d > 0) {
if (d & 1) solve(res, co);
solve(co, co);
d /= 2;
}
int ans = 0;
for (int i = 0; i < n; i++) ans = (ans + res[i]) % x;
printf("%d\n", ans);
return 0;
}