题面
解题思路
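考虑如何求解阶乘：做法与“求阶乘的最低十八位”（即 $n! \bmod 10^{18}$）类似。组合数能够写成阶乘的形式，答案为 $\binom{n+m}{m}\cdot\dfrac{n-m+1}{n+1}=\dfrac{(n+m)!\,(n-m+1)}{(n+1)!\,m!}$。由于 $10^{18}=2^{18}\cdot 5^{18}$，分别在模 $2^{18}$ 与模 $5^{18}$ 下计算：把阶乘中质因子 $p$ 的次数单独统计（Legendre 公式），剩余与 $p$ 互质的部分可以直接求逆；最后用中国剩余定理合并，再乘回被提出的 $2$ 与 $5$ 的幂。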
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
#pragma GCC optimize(2)
typedef long long ll;
// Signed 128-bit integer: wide enough for the product of two residues below 10^18.
typedef __int128 Int;

// (a + b) mod `mod`, assuming 0 <= a, b < mod (a single conditional subtraction).
Int add(Int a, Int b, Int mod) {
    const Int sum = a + b;
    if (sum < mod) return sum;
    return sum - mod;
}

// (a * b) mod `mod`; the product a * b must fit in a signed 128-bit integer.
Int mul(Int a, Int b, Int mod) {
    return a * b % mod;
}

// x^n mod `mod` by binary (square-and-multiply) exponentiation; n is expected >= 0.
Int qpow(Int x, Int n, Int mod) {
    Int result = 1;
    while (n) {
        if (n & 1) result = mul(x, result, mod);
        n >>= 1;
        x = mul(x, x, mod);
    }
    return result;
}

// Extended Euclid: sets x, y such that a*x + b*y == gcd(a, b).
void exgcd(Int a, Int b, Int &x, Int &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return;
    }
    exgcd(b, a % b, y, x);
    y -= a / b * x;
}

// Inverse of a modulo p, normalized into [0, p); only meaningful when gcd(a, p) == 1.
Int rev(Int a, Int p) {
    Int inv, other;
    exgcd(a, p, inv, other);
    return (inv + p) % p;
}
// Solver for a single prime-power modulus: main instantiates it as a2 (p = 2,
// mod = 2^18) and a5 (p = 5, mod = 5^18). After solve(), the requested value
// equals ans * p^cnt, where ans is the p-free part reduced modulo `mod` and
// cnt is the exponent of p.
struct node {
// s2[i][j]: Stirling numbers of the second kind S(i, j) modulo `mod`, filled by init().
// Entries never written by init() (e.g. s2[i][0] for i > 0) rely on the
// zero-initialization of the namespace-scope objects a2 and a5 declared below.
Int s2[110][110];
// Fills rows 0..k+10 of s2 using S(i, j) = S(i-1, j-1) + j * S(i-1, j).
void init(Int k, Int mod) {
s2[0][0] = 1;
for (int i = 1; i <= k + 10; i++) {
for (int j = 1; j <= i; j++) {
s2[i][j] = add(s2[i - 1][j - 1], mul(j, s2[i - 1][j], mod), mod);
}
}
}
// Returns the power sum 1^k + 2^k + ... + n^k modulo `mod` (intended for k >= 1), via
//   sum_{x=0}^{n} x^k = sum_{i=1}^{k} S(k, i) * (n+1) * n * ... * (n-i+1) / (i+1).
// Terms with i > n vanish (the falling factorial contains the factor 0), hence i <= n.
// Requires init() to have filled row k of s2.
Int get(Int n, Int k, Int mod) {
if (n == 0) return 0;
Int res = 0;
for (Int i = 1; i <= k && i <= n; i++) {
Int sum = s2[k][i], flag = true;
// The i+1 consecutive integers n-i+1 .. n+1 contain exactly one multiple of i+1.
// Dividing that single factor (flag ensures it happens once) performs the exact
// division by i+1 on an unreduced integer, before any reduction modulo `mod`.
for (Int j = n - i + 1; j <= n + 1; j++) {
if (j % (i + 1) == 0 && flag) sum = mul(sum, j / (i + 1), mod), flag = false;
else sum = mul(sum, j, mod);
}
res = add(res, sum, mod);
}
return res;
}
// Returns prod_{i=0}^{t} (i*p + d) modulo `mod`, i.e. the product of the t+1
// smallest positive integers congruent to d modulo p (f() passes 1 <= d < p).
// Small t is multiplied out directly. Otherwise the product is expanded as
//   sum_s d^(t+1-s) * p^s * e_s,   e_s = s-th elementary symmetric polynomial of 0..t,
// and only s < k is kept: with mod = p^k (as the callers in this file use),
// p^s * e_s is 0 modulo `mod` for s >= k. The e_s come from Newton's identities
//   s * e_s = sum_{j=1}^{s} (-1)^(j-1) * e_{s-j} * P_j,   P_j = 1^j + ... + t^j.
Int g(Int t, Int d, Int p, Int k, Int mod) {
if (t <= k) {
Int res = 1;
for (Int i = 0; i <= t; i++) res = mul(res, add(mul(i, p, mod), d, mod), mod);
return res;
}
else {
// NOTE(review): these static scratch arrays are shared by every node instance,
// so g() is not reentrant; that is fine for the sequential calls in this file.
// Meanings (all modulo `mod`):
//   cd[i] = d^(t+1-i),  cc[j] = P_j,  cp[i] = p^i,
//   cnt[i] = exponent of p in i!,  dp[i] = e_i * p^cnt[i].
static Int dp[110], cd[110], cc[110], cp[110], cnt[110];
cd[k - 1] = qpow(d, t - k + 2, mod);
for (Int i = k - 2; i >= 0; i--) cd[i] = mul(d, cd[i + 1], mod);
cc[0] = 1;
for (Int i = 1; i < k; i++) cc[i] = get(t, i, mod);
cp[0] = 1;
for (Int i = 1; i < k; i++) cp[i] = mul(cp[i - 1], p, mod);
cnt[0] = 0;
for (Int i = 1; i < k; i++) cnt[i] = 0;
// ma = exponent of p in (i-1)! at the start of iteration i.
dp[0] = 1; Int ma = 0;
for (Int i = 1; i < k; i++) {
dp[i] = 0;
// Newton's identity multiplied by p^cnt[i] / i:
//   e_i * p^v(i!) = (1/x) * sum_j (-1)^(j-1) * P_j * dp[i-j] * p^(v((i-1)!) - v((i-j)!)),
// where x is the p-free part of i. The p-part of i is absorbed by the
// power-of-p offsets, so only the unit x has to be inverted (no division by p).
for (int j = 1, f = 1; j <= i; j++, f = -f) {
Int r = mul(cc[j], mul(dp[i - j], cp[ma - cnt[i - j]], mod), mod);
if (f > 0) dp[i] = add(dp[i], r, mod);
else dp[i] = add(dp[i], mod - r, mod);
}
// Split i into p^v(i) * x; after this, cnt[i] = v(i!) and ma = v(i!).
Int x = i; cnt[i] = ma;
while (x % p == 0) x /= p, cnt[i]++, ma++;
dp[i] = mul(dp[i], rev(x, mod), mod);
}
// Sum of d^(t+1-i) * p^i * e_i, where p^(i - cnt[i]) * dp[i] = p^i * e_i.
Int res = 0;
for (int i = 0; i < k; i++) res = add(res, mul(cp[i - cnt[i]], mul(cd[i], dp[i], mod), mod), mod);
return res;
}
}
// Returns the p-free part of n!, i.e. n! / p^(v_p(n!)), modulo `mod`.
// Numbers <= n coprime to p are grouped by residue d = 1..p-1. Residue d
// contributes g(t, d) with t = floor((n - d) / p) (the expression passed below).
// The multiples of p contribute p^(n/p) * (n/p)!, whose p-free part is f(n/p).
Int f(Int n, Int p, Int k, Int mod) {
if (n < p) {
Int res = 1;
for (int i = 2; i <= n; i++) res = mul(res, i, mod);
return res;
}
else {
Int res = f(n / p, p, k, mod);
for (int i = 1; i < p; i++) res = mul(res, g(n / p + (n % p >= i ? 0 : -1), i, p, k, mod), mod);
return res;
}
}
// Result of solve(): value = ans * p^cnt, where ans is coprime to p (modulo `mod`).
Int ans, cnt;
// Multiplies (op == 1) or divides (any other op; solve passes -1) ans by the
// p-free part of n!, and adds op * v_p(n!) to cnt using Legendre's formula
// v_p(n!) = floor(n/p) + floor(n/p^2) + ...
void Fac(Int n, Int p, Int k, Int mod, Int op) {
if (op == 1) ans = mul(ans, f(n, p, k, mod), mod);
else ans = mul(ans, rev(f(n, p, k, mod), mod), mod);
while (n > 0) cnt += n / p * op, n /= p;
}
// Computes (n+m)! * (n-m+1) / ((n+1)! * m!) in the form ans * p^cnt.
// NOTE(review): if n - m + 1 == 0 the while loop below never terminates;
// main only calls this after rejecting m == 0 and m > n.
void solve(Int n, Int m, Int p, Int k, Int mod) {
init(k, mod);
ans = 1, cnt = 0;
Fac(n + m, p, k, mod, 1);
Fac(n + 1, p, k, mod, -1); Fac(m, p, k, mod, -1);
// Split the extra factor (n-m+1) into its power of p and its p-free part.
Int x = n - m + 1;
while (x % p == 0) cnt++, x /= p;
ans = mul(ans, x, mod);
}
// a2 works modulo 2^18 and a5 modulo 5^18 (see main).
}a2, a5;
ll n, m;
int main() {
//freopen("0.txt", "r", stdin);
scanf("%lld%lld", &m, &n);
if (m == 0) puts("1");
else if (m > n) puts("0");
else {
Int n2 = qpow(2, 18, 1e20), n5 = qpow(5, 18, 1e20), nn = n2 * n5;
a2.solve(n, m, 2, 18, n2);
a5.solve(n, m, 5, 18, n5);
Int r2 = a2.ans, r5 = a5.ans;
Int c2 = a2.cnt, c5 = a5.cnt;
r5 = mul(r5, qpow(rev(2, n5), c2, n5), n5);
r2 = mul(r2, qpow(rev(5, n2), c5, n2), n2);
Int t5 = Int(n2) * rev(n2, n5) % nn * r5 % nn;
Int t2 = Int(n5) * rev(n5, n2) % nn * r2 % nn;
Int r = (t5 + t2) % nn;
r = mul(r, qpow(2, c2, nn), nn);
r = mul(r, qpow(5, c5, nn), nn);
printf("%lld\n", ll(r));
}
return 0;
}