题目
正解
推式子题。
比赛时推了半天的生成函数最终推回了一个递推式?
容斥一下,答案为
∑
i
=
0
N
(
−
1
)
i
C
N
i
C
S
−
i
T
M
\sum_{i=0}^N(-1)^iC_N^iC_{S-iT}^M
∑i=0N(−1)iCNiCS−iTM
然后就是推式子:
=
∑
i
=
0
N
(
−
1
)
i
C
N
i
[
x
M
]
(
1
+
x
)
S
−
i
T
=
[
x
M
]
∑
i
=
0
N
(
−
1
)
i
C
N
i
(
1
+
x
)
(
n
−
i
)
T
(
1
+
x
)
S
−
n
T
=
[
x
M
]
(
(
1
+
x
)
(
n
−
i
)
T
−
1
)
N
(
1
+
x
)
S
−
n
T
=
[
x
M
−
N
]
(
(
1
+
x
)
(
n
−
i
)
T
−
1
x
)
N
(
1
+
x
)
S
−
n
T
=\sum_{i=0}^N(-1)^iC_N^i[x^M](1+x)^{S-iT}\\ =[x^M]\sum_{i=0}^N(-1)^iC_N^i(1+x)^{(n-i)T}(1+x)^{S-nT} \\ =[x^M]((1+x)^{(n-i)T}-1)^N(1+x)^{S-nT} \\ =[x^{M-N}](\frac{(1+x)^{(n-i)T}-1}{x})^N(1+x)^{S-nT}
=i=0∑N(−1)iCNi[xM](1+x)S−iT=[xM]i=0∑N(−1)iCNi(1+x)(n−i)T(1+x)S−nT=[xM]((1+x)(n−i)T−1)N(1+x)S−nT=[xM−N](x(1+x)(n−i)T−1)N(1+x)S−nT
里面的那项和右边的那项二项式展开,可以快速计算。
然后左边那项先
l
n
ln
ln,乘上指数之后再
e
x
p
exp
exp,即可以得到乘方。
代码
经过隔壁大佬指点,从此NTT少两个模法。
using namespace std;
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 524288
#define ll long long
#define mo 998244353
#define mo2 998244353ll*998244353ll
ll qpow(ll x,ll y=mo-2){
ll r=1;
for (;y;y>>=1,x=x*x%mo)
if (y&1)
r=r*x%mo;
return r;
}
ll inv[N+10];
ll n,m,s,t;
int nN,re[N];
void setlen(int n){
int bit=0;
for (nN=1;nN<2*n;nN<<=1,++bit);
re[0]=0;
for (int i=1;i<nN;++i)
re[i]=re[i>>1]>>1|(i&1)<<bit-1;
}
void clear(ll A[],int n){
memset(A,0,sizeof(ll)*n);
}
void dft(ll A[],int flag){
static ll w[N];
for (int i=0;i<nN;++i)
if (i<re[i])
swap(A[i],A[re[i]]);
for (int i=1;i<nN;i<<=1){
ll wn=qpow(3,flag==1?(mo-1)/(2*i):mo-1-(mo-1)/(2*i));
w[0]=1;
for (int k=1;k<i;++k)
w[k]=w[k-1]*wn%mo;
for (int j=0;j<nN;j+=i<<1){
ll wnk=1;
for (int k=0;k<i;++k){
ll x=A[j+k],y=A[j+k+i]*w[k];
A[j+k]=(x+y)%mo;
A[j+k+i]=(x-y+mo2)%mo;
}
}
}
if (flag==-1){
ll invn=inv[nN];
for (int i=0;i<nN;++i)
A[i]=A[i]*invn%mo;
}
}
void multi(ll c[],ll a[],ll b[],int n){
static ll A[N],B[N];
setlen(n);
clear(A,nN);
for (int i=0;i<n;++i)
A[i]=a[i];
dft(A,1);
if (a!=b){
clear(B,nN);
for (int i=0;i<n;++i)
B[i]=b[i];
dft(B,1);
for (int i=0;i<nN;++i)
c[i]=A[i]*B[i]%mo;
}
else{
for (int i=0;i<nN;++i)
c[i]=A[i]*A[i]%mo;
}
dft(c,-1);
for (int i=n;i<nN;++i)
c[i]=0;
}
void getinv(ll B[],ll A[],int n){
static ll t[N],t1[N];
int nn=1;
for (;nn<n;nn<<=1);
clear(B,nn);
B[0]=qpow(A[0]);
for (int i=2;i<=nn;i<<=1){
setlen(i);
clear(t,nN),clear(t1,nN);
for (int j=0;j<i;++j)
t[j]=B[j],t1[j]=A[j];
dft(t,1),dft(t1,1);
for (int j=0;j<nN;++j)
t[j]=t[j]*t[j]%mo*t1[j]%mo;
dft(t,-1);
for (int j=0;j<i;++j)
B[j]=(2*B[j]-t[j]+mo)%mo;
}
for (int i=n;i<=nn;++i)
B[i]=0;
}
void getln(ll B[],ll A[],int n){
static ll A_[N],t[N];
for (int i=1;i<n;++i)
A_[i-1]=A[i]*i%mo;
A_[n-1]=0;
getinv(t,A,n);
multi(B,A_,t,n);
for (int i=n-1;i>=1;--i)
B[i]=B[i-1]*inv[i]%mo;
B[0]=0;
}
void getexp(ll B[],ll A[],int n){
static ll t[N];
B[0]=1;
int m=0;
for (;1<<m<=n;++m){
getln(t,B,1<<m);
t[0]=(1+A[0]-t[0]+mo)%mo;
for (int j=1;j<1<<m;++j)
t[j]=(A[j]-t[j]+mo)%mo;
multi(B,B,t,1<<m+1);
}
getln(t,B,n);
t[0]=(1+A[0]-t[0]+mo)%mo;
for (int j=1;j<n;++j)
t[j]=(A[j]-t[j]+mo)%mo;
multi(B,B,t,n);
}
void getpow(ll A[],ll k,int n){
static ll t[N];
ll c=0,d=0;
for (int i=0;i<n;++i)
if (A[i]){
c=A[i],d=i;
break;
}
int invc=qpow(c);
for (int i=d;i<n;++i)
A[i-d]=A[i]*invc%mo;
for (int i=n-d;i<n;++i)
A[i]=0;
getln(t,A,n);
k%=mo;
for (int i=0;i<n;++i)
t[i]=t[i]*k%mo;
getexp(A,t,n);
c=qpow(c,k);
d*=k;
for (int i=n-1;i>=d;--i)
A[i]=A[i-d]*c%mo;
for (int i=min((ll)n-1,d-1);i>=0;--i)
A[i]=0;
}
ll F[N],G[N];
void initC(ll F[],ll k,int n){
ll c=1;
F[0]=1;
for (int i=1;i<=n;++i){
c=c*((k-i+1)%mo)%mo*inv[i]%mo;
F[i]=c;
}
}
int main(){
freopen("sum.in","r",stdin);
freopen("sum.out","w",stdout);
inv[0]=1,inv[1]=1;
for (int i=2;i<=N;++i)
inv[i]=(ll)(mo-mo/i)*inv[mo%i]%mo;
scanf("%lld%lld%lld%lld",&s,&t,&n,&m);
// F[0]=1,F[1]=1;
// getpow(F,t,m-n+2);
initC(F,t,m-n+1);
F[0]--;
for (int i=0;i<=m-n;++i)
F[i]=F[i+1];
F[m-n+1]=0;
getpow(F,n,m-n+1);
G[0]=1,G[1]=1;
initC(G,s-n*t,m-n);
// getpow(G,s-n*t,m-n+1);
multi(F,F,G,m-n+1);
printf("%lld\n",F[m-n]);
return 0;
}