题目
给定n,m,求长为n(n<=1e7)且满足以下条件的序列a的个数,答案对998244353取模
①0<=a1<=a2<=...<=an<=m
②a_{i}%3与a_{i+1}%3不同,即相邻项模3的余数不同
思路来源
megurine烟花佬
第一层枚举的是 a1 % 3 的值,
第二层循环枚举的是一共 (a[i + 1] - a[i]) % 3 的值中2的个数,
一共 n - 1个差分数,然后再加上m的值可能不需要用完
也就是说余数,所以一共 n + 1个数,最后一个数是m - a[n]
题解
考虑构造差分序列b,b1=a1,b2=a2-a1,...
则最终b2到bn中不能有含0的项,只能为1或2,
枚举a1%3的值0,1,2,枚举后n-1项里有x项%3=2,则其他项%3=1
其余的相邻项之间的增量,只能用3的倍数来填,
相当于a0=0,a_{n+1}=m,中间n个数,共计n+2个数,
差分数组b共n+1项,每一项delta都是一个>=0的3的倍数
(b_{n+1}即最后一项可能不是,但只关心最后一项/3*3的值)
n+1个人平分总量(m-(n-1)-c2-a0)/3,每个人允许为0
插空法知x个人平分y每个人>=0的方案数为C(x+y-1,x-1),代入即可
心得
n+1个数,可能不需要用完,
这个trick,感觉是典中典,之前好像也是在abc遇到过
代码
#include <bits/stdc++.h>
using namespace std;
const int N=2e7+10,mod=998244353;
int n,m,ans;
int Finv[N],fac[N],inv[N];
int modpow(int x,int n,int mod){
int res=1;
for(;n;x=1ll*x*x%mod,n>>=1)
if(n&1)res=1ll*res*x%mod;
return res;
}
void init(int n){ //n<N
inv[1]=1;
for(int i=2;i<=n;++i)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
fac[0]=Finv[0]=1;
for(int i=1;i<=n;++i)fac[i]=1ll*fac[i-1]*i%mod,Finv[i]=1ll*Finv[i-1]*inv[i]%mod;
//Finv[n]=modpow(fac[n],mod-2,mod);
//for(int i=n-1;i>=1;--i)Finv[i]=1ll*Finv[i+1]*(i+1)%mod;
}
int C(int n,int m){
if(m<0||m>n)return 0;
return 1ll*fac[n]*Finv[n-m]%mod*Finv[m]%mod;
}
int main(){
init(N-5);
scanf("%d%d",&n,&m);
for(int a0=0;a0<3;++a0){
if(m-a0-(n-1)<0)continue;
for(int c2=0;c2<=n-1 && c2<=m-a0-(n-1);++c2){
ans=(ans+1ll*C(n-1,c2)*C((m-a0-(n-1)-c2)/3+n,n)%mod)%mod;// n+1个数 n个空位
}
}
printf("%d\n",ans);
return 0;
}
/*
3 1
3 2
3 7
3 3
3 4
3 5
0
1
4
8
14
*/