链接:
题意:
有
n+m
个询问,其中
n
个是YES,
题解:
显然最优策略是回答多的那个。
不妨设
n≥m
,将正确答案画在二维平面上,发现是
(n,m)
到
(0,0)
的一条路径。
画一条对角线
y=x
,我们发现如果路径不经过对角线那么答案就是
n
。
我们考虑计算答案的增量,发现在对角线处才会贡献有
代码:
#include <bits/stdc++.h>
#define xx first
#define yy second
#define mp make_pair
#define pb push_back
#define mset(x, y) memset(x, y, sizeof x)
#define mcpy(x, y) memcpy(x, y, sizeof x)
using namespace std;
typedef long long LL;
typedef pair <int, int> pii;
inline int Read()
{
int x = 0, f = 1, c = getchar();
for (; !isdigit(c); c = getchar())
if (c == '-')
f = -1;
for (; isdigit(c); c = getchar())
x = x * 10 + c - '0';
return x * f;
}
const int MAXN = 1000005;
const int mod = 998244353;
int n, m, ans, sum, fac[MAXN], inv[MAXN];
inline int C(int n, int m)
{
return 1LL * fac[n] * inv[m] % mod * inv[n - m] % mod;
}
int main()
{
#ifdef wxh010910
freopen("data.in", "r", stdin);
#endif
n = Read(), m = Read();
if (n < m)
swap(n, m);
ans = n;
fac[0] = fac[1] = inv[0] = inv[1] = 1;
for (int i = 2; i <= n + m; i ++)
fac[i] = 1LL * fac[i - 1] * i % mod, inv[i] = 1LL * (mod - mod / i) * inv[mod % i] % mod;
for (int i = 2; i <= n + m; i ++)
inv[i] = 1LL * inv[i] * inv[i - 1] % mod;
for (int i = 1; i <= m; i ++)
sum = (1LL * C(i << 1, i) * C(n + m - (i << 1), n - i) + sum) % mod;
return printf("%d\n", (1LL * sum * inv[2] % mod * inv[n + m] % mod * fac[n] % mod * fac[m] + ans) % mod), 0;
}