今晚hu老师发了我一道题目,好久没碰过了,没想到写了一下还ok。
题意是
时间限制: 1.0 秒
空间限制: 256 MB
输入m,n,x,y表示,序列长度是2*n,序列中最大数是m,合法序列定义:前n个数不下降,后n个数不上升,x和y位置的数相等。求和法序列的数目ans%998244353后的值。(其中 1≤x<y≤2n 1≤n,m≤10^5)
样例1输入
3 2 1 3
样例1输出
10
样例2输入
1000 1000 535 1477
样例2输出
295916566
思路:根据x,y的的位置分三种情况, 然后造几个简单的数据算一下,会发现递推公式写出来是一些组合数,因为n,m是1e5,所以需要预处理一下,dp预处理,然后统计结果就好了。
(因为不知道题目来源,所以只测试了样例数据,代码可能还有瑕疵)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5;
const ll mod=998244353;
ll qmod(ll x,ll p)
{
ll ans=1;
while(p)
{
if(p&1) ans=ans*x%mod;
x=x*x%mod;
p>>=1;
}
return ans;
}
ll fac[N+10],inv[N+10];
void init()
{
fac[0]=1;
for(int i=1;i<=N;i++)
{
fac[i]=fac[i-1]*i%mod;
if(fac[i]==0)
fac[i]=fac[i]+mod;
}
inv[N]=qmod(fac[N],mod-2);
for(int i=N-1;i>=0;i--)
inv[i]=(i+1)*inv[i+1]%mod;
}
ll C(ll n,ll m)
{
if(m>n) return 0;
return fac[n]*inv[n-m]%mod*inv[m]%mod;
}
ll n,m,x,y;
ll dp[5][N+5],ans=0;
int main()
{
init();
scanf("%d%d%d%d",&m,&n,&x,&y);
if(y<=n)//1<x<y<=n
{
for(int val=1;val<=m;val++)
{
for(int i=1;i<=4;i++)
{
int len;
if(i==1) len=x-1;
if(i==2) len=y-x;
if(i==3) len=n-y;
if(i==4) len=n;
dp[i][val]=dp[i][val-1]+C(len+val-2,val-1);
dp[i][val]%=mod;
}
}
for(int val=1;val<=m;val++)
{
ll tmp=1;
for(int i=1;i<=4;i++)
{
int k,len;
if(i==1) k=val;
if(i==2) k=1;
if(i==3) k=m-val+1;
if(i==4) k=m;
if(dp[i][k]>0)
tmp=tmp*dp[i][k]%mod;
}
(ans+=tmp)%=mod;
}
}
if(x>n)//n<x<y<=2*n
{
for(int val=1;val<=m;val++)
{
for(int i=1;i<=4;i++)
{
int len;
if(i==1) len=n;
if(i==2) len=x-n-1;
if(i==3) len=y-x;
if(i==4) len=2*n-y;
dp[i][val]=dp[i][val-1]+C(len+val-2,val-1);
dp[i][val]%=mod;
}
}
for(int val=1;val<=m;val++)
{
ll tmp=1;
for(int i=1;i<=4;i++)
{
int k,len;
if(i==1) k=m;
if(i==2) k=m-val+1;
if(i==3) k=1;
if(i==4) k=val;
if(dp[i][k]>0)
tmp=tmp*dp[i][k]%mod;
}
(ans+=tmp)%=mod;
}
}
if(x<=n&&y>n)//x<=n<y<=2*n
{
for(int val=1;val<=m;val++)
{
for(int i=1;i<=4;i++)
{
int k,len;
if(i==1) len=x-1;
if(i==2) len=n-x;
if(i==3) len=y-n-1;
if(i==4) len=2*n-y;
dp[i][val]=dp[i][val-1]+C(len+val-2,val-1);
dp[i][val]%=mod;
}
}
for(int val=1;val<=m;val++)
{
ll tmp=1;
for(int i=1;i<=4;i++)
{
int k,len;
if(i==1) k=val;
if(i==2) k=m-val+1;
if(i==3) k=m-val+1;
if(i==4) k=val;
if(dp[i][k]>0)
tmp=tmp*dp[i][k]%mod;
}
(ans+=tmp)%=mod;
}
}
printf("%lld\n",ans);
return 0;
}