题目大意
给出n+m个问题,其中n个答案是yes,m个no,每一次回答问题,你选择回答yes或no,你能知道当前局面未回答问题的yes和no的数量。问最后你期望回答对多少个问题。
n,m≤5e5,在模998244353意义下计算答案。
解题思路
设剩下a个yes,b个no的局面是(a,b)
首先手玩一下这个问题,很容易列出一个dp转移式,然后发现最优策略:对于(a,b),我们肯定是猜数量大的那一个(若a>b,猜yes),相等就随便猜一个。
我们画一个网格图,左上角是(n,m),右下角是(0,0),那么一条从(n,m)到(0,0),只走右和下的路径就代表了一种情况。我们根据最优策略可以知道,一个点只有往某一个方向走才会猜对1道,很容易发现这个网格图的哪些边是有贡献的。那么很显然的思路就是,统计每一条边会被经过多少次,然后除以总方案数即可。
考虑另一个思路。我们做一条y=x的直线,对于路径(a,b)->(0,0),假如不经过直线,那他的贡献一定是max(a,b),就是你一直猜大的那个。
再考虑一般的路径(n,m)->(0,0),忽略掉形如(i,i)的点及其出边,他被拆分成若干段不经过直线的路径,这些路径贡献和加起来一定是max(n,m)。(点(i,i)有可能走到(i-1,i),(i,i-1),他们到(0,0)的路径的贡献是一样的)
剩下(i,i)们没有考虑,我们知道它有1/2的几率猜对,那么我们只需要分别统计能到(i,i)的方案数即可。最后乘1/2.
所以整个算法只用算(i,i)的贡献最后加max(n,m)即可。
代码
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
#define cmax(a,b) (a=(a>b)?a:b)
#define cmin(a,b) (a=(a<b)?a:b)
typedef long long ll;
const int N=1e6+5,M=2e6+5,mo=998244353;
int fac[N],rev[N],n,m,ans,i;
int ksm(int x,int y)
{
int ret=1;
while (y)
{
if (y&1) ret=1ll*ret*x%mo;
y>>=1;
x=1ll*x*x%mo;
}
return ret;
}
int c(int n,int m)
{
return 1ll*fac[m]*rev[n]%mo*rev[m-n]%mo;
}
int main()
{
freopen("t1.in","r",stdin);
//freopen("t1.out","w",stdout);
scanf("%d %d\n",&n,&m);
if (n<m) swap(n,m);
fac[0]=1;
fo(i,1,n+m) fac[i]=1ll*fac[i-1]*i%mo;
rev[n+m]=ksm(fac[n+m],mo-2);
fd(i,n+m,1) rev[i-1]=1ll*rev[i]*i%mo;
fo(i,1,m)
{
ans=(ans+1ll*c(i,i*2)*c(n-i,n+m-2*i))%mo;
}
ans=1ll*ans*ksm(2*c(n,n+m),mo-2)%mo;
ans=(ans+n+mo)%mo;
printf("%d",ans);
}