设m < n,答案为n + C(n + m + 1, m)
所以求出组合数就好了...
发现n * m最大为10^12,那么m最大为10^6。
将组合数用阶乘形式展开,为 (n + m + 1)! / n! / (m + 1)!。
将前两项约掉,得到
(n + 1) * (n + 2) * ... * (n + m + 1) / (m + 1)!
分子有m项,分母有m项,所以可以O(m)求出来。
因为有inv(x!, p) = inv(x! % p, p),所以求出分母模p的值之后再求逆元,那么只需要求一次逆元。
总复杂度为O(m + log(p))。
写出来之后发现跑了rank1...
#include <cstdio>
#include <algorithm>
using namespace std;
typedef unsigned long long ULL;
typedef long long LL;
const ULL p = 1000000007;
void exgcd(ULL a, ULL b, LL &x, LL &y) {
b ? (exgcd(b, a % b, y, x), y -= a / b * x) : (x = 1, y = 0);
}
ULL inv(ULL a) {
LL x, y;
exgcd(a, p, x, y);
if(x < 0) x += p;
return x % p;
}
int main() {
ULL n, m; scanf("%llu%llu", &n, &m);
if(m > n) swap(n, m); n += m + 1; n %= p;
ULL a = 1, b = 1;
for(int i = 1; i <= m; i++) {
a = (n - i + 1) * a % p;
b = i * b % p;
}
ULL ans = (a * inv(b) % p + n - m - 1) % p;
printf("%llu\n", ans);
return 0;
}