题目链接:BZOJ 4338
题目大意:用数字1~k填一个n*m的表格(n行m列),每种数字可用任意次,要求每一行从第1列到第m列单调不减,且任意两行不完全相同,求方案数对P取模的值。
题解:扩展Lucas+CRT模板题,板子还不是太熟悉,贴到这里方便复习,有空回来加点注释。最后答案的式子比较容易得到,是 $A_{C_{m+k-1}^{m}}^{n} \bmod P$:每一行是长度为m、取值于1~k的不减序列,共 $C_{m+k-1}^{m}$ 种;再从中有序地选出n个互不相同的行,即排列数 $A_{C_{m+k-1}^{m}}^{n}$。
code
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
// N bounds the global tables jc[] and inv[]. solve() fills jc[0..D-1] for a
// prime power D of mod, and inv[1..m] when D is a prime larger than m, so both
// are assumed to stay below N (presumably guaranteed by the problem limits).
#define N 100005
using namespace std;
// 64-bit alias (not referenced in the code shown; 1ll literals are used instead).
typedef long long ll;
// Reads the next integer from standard input.
// Characters before the first digit are skipped; a '-' among them makes the
// result negative. If end of input is reached before any digit, returns 0
// instead of spinning forever on EOF.
inline int read()
{
    int c = getchar();   // int, not char, so EOF stays distinguishable
    int num = 0, f = 1;
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        num = num * 10 + c - '0';
        c = getchar();
    }
    return num * f;
}
// Global state shared by the extended-Lucas routines:
//   n, m, k, mod : input -- rows, columns, value range 1..k, modulus
//   d[i], D[i]   : i-th prime factor of mod and the full power of it dividing mod (filled by dvd)
//   phi[i]       : Euler's totient of D[i]
//   mo           : index of the prime power currently being processed (read by newnum/getjc)
//   ans[i]       : the answer modulo D[i], recombined by CRT()
//   num          : number of distinct prime factors of mod
//   jc[i]        : product of 1..i skipping multiples of d[mo], modulo D[mo]
//   inv[i]       : inverse of i modulo D[mo], filled only when D[mo] is a prime > m
int n,m,k,mod,d[15],D[15],mo,phi[15],ans[15],num,jc[N],inv[N];
// Returns x + y reduced by p at most once, i.e. (x + y) mod p when 0 <= x, y < p.
// The sum is formed in 64 bits so it cannot overflow int even when p is close
// to INT_MAX (the modulus is read as a plain int).
inline int add(int x,int y,int p)
{
    long long s = static_cast<long long>(x) + y;
    if (s >= p) s -= p;
    return static_cast<int>(s);
}
// Binary exponentiation: returns a^b mod p.
// For b == 0 the result is 1 without reduction (only differs when p == 1).
inline int ksm(int a,int b,int p)
{
    long long result = 1;
    long long base = a;
    while (b)
    {
        if (b & 1) result = result * base % p;
        base = base * base % p;
        b >>= 1;
    }
    return static_cast<int>(result);
}
// A number of the form val * d[mo]^tmp, with val kept modulo D[mo].
// getjc() builds factorials in this form (tmp counts the removed factors of
// d[mo]); solve() converts back with val * d[mo]^tmp mod D[mo].
struct newnum{
    int val,tmp;
    newnum (int _val,int _tmp) : val(_val), tmp(_tmp) {}
    // Product: multiply the unit parts modulo D[mo], add the exponents.
    newnum operator * (const newnum U) const
    {
        const int unitPart = static_cast<int>(1ll * val * U.val % D[mo]);
        return newnum(unitPart, tmp + U.tmp);
    }
    // Quotient: multiply by U.val^(phi(D[mo]) - 1), which is U.val's inverse
    // by Euler's theorem -- assumes U.val is coprime to d[mo]. Subtract exponents.
    newnum operator / (const newnum U) const
    {
        const int unitInverse = ksm(U.val, phi[mo] - 1, D[mo]);
        const int unitPart = static_cast<int>(1ll * val * unitInverse % D[mo]);
        return newnum(unitPart, tmp - U.tmp);
    }
};
// Factorises p into prime powers, appending one entry per distinct prime:
//   d[num]   - the prime
//   D[num]   - the full power of that prime dividing p
//   phi[num] - Euler's totient of D[num]
// num is advanced once per distinct prime factor.
inline void dvd(int p)
{
    for (int i=2;1ll*i*i<=p;i++)
        if (p%i==0)
        {
            ++num;
            d[num]=i; D[num]=1;
            while (p%i==0) p/=i,D[num]*=i;
            phi[num]=D[num]/d[num]*(d[num]-1);
            if (p==1) break;
        }
    // Whatever remains (> 1) is a single prime factor.
    // num is advanced in its own statement so d[num] and D[num] address the
    // same slot: in a chained assignment that mutates the index, the
    // right-hand operand is evaluated first (C++17) or unsequenced (earlier).
    if (p!=1)
    {
        ++num;
        d[num]=D[num]=p;
        phi[num]=p-1;
    }
}
// Returns x! relative to the current prime power D[mo] (prime d[mo]):
//   val = product of the factors of x! coprime to d[mo], modulo D[mo]
//   tmp = exponent of d[mo] in x!  (sum of x/d + x/d^2 + ...)
// Uses x! = d^(x/d) * (x/d)! * [non-multiples of d up to x], unrolled into a loop.
// Requires jc[] to have been filled for D[mo].
inline newnum getjc(int x)
{
    newnum acc(1,0);
    for (; x > 0; x /= d[mo])
    {
        // Each full block of D[mo] consecutive integers contributes jc[D[mo]-1].
        if (x >= D[mo]) acc = acc * newnum(ksm(jc[D[mo]-1], x/D[mo], D[mo]), 0);
        // The trailing partial block contributes jc[x mod D[mo]].
        acc = acc * newnum(jc[x % D[mo]], 0);
        // The x/d multiples of d each give one factor of d; recurse on them next.
        acc.tmp += x / d[mo];
    }
    return acc;
}
// Binomial coefficient C(n, m) = n! / (m! (n-m)!) as a newnum relative to
// D[mo]: unit part modulo D[mo], exponent of d[mo] kept separately.
inline newnum getC(int n,int m)
{
    const newnum whole = getjc(n);
    const newnum left = getjc(m);
    const newnum right = getjc(n - m);
    return whole / left / right;
}
// Falling factorial n * (n-1) * ... * (n-m+1) modulo p, i.e. the number of
// ordered selections A(n, m) reduced mod p. Expects n >= 0 and p >= 1.
// The empty product (m <= 0) is 1 mod p.
inline int getP(int n,int m,int p)
{
    if (m <= 0) return 1 % p;
    int cur = n % p;          // current factor, kept in [0, p)
    int ret = cur;
    for (int i = 2; i <= m; i++)
    {
        cur = (cur == 0) ? p - 1 : cur - 1;   // cur - 1 (mod p) without int overflow
        ret = static_cast<int>(1ll * ret * cur % p);
    }
    return ret;
}
void solve(int x)
{
mo=x,jc[0]=1; int p=D[x];
if (d[mo]==p&&p>m)
{
inv[1]=1;
for(int i=2;i<=m;i++) inv[i]=p-1ll*inv[p%i]*(p/i)%p;
ans[mo]=1; int now=k+m-1;
for (int i=1;i<=m;i++)
ans[mo]=1ll*ans[mo]*now%p*inv[i]%p,now=add(now,p-1,p);
ans[mo]=getP(ans[mo],n,p);
}
else
{
for (int i=1;i<p;i++)
{
jc[i]=jc[i-1];
if (i%d[mo]) jc[i]=1ll*jc[i]*i%p;
}
newnum now=getC(k+m-1,m);
ans[mo]=1ll*now.val*ksm(d[mo],now.tmp,p)%p;
ans[mo]=getP(ans[mo],n,p);
}
}
// Chinese Remainder Theorem: combines ans[i] (mod D[i]) into one value mod `mod`.
// Term i is ans[i] * M_i * M_i^{-1}, where M_i = mod / D[i] and the inverse
// modulo D[i] is M_i^(phi(D[i]) - 1) by Euler's theorem.
inline int CRT()
{
    int total = 0;
    for (int i = 1; i <= num; i++)
    {
        mo = i;   // kept as in the original; nothing in this loop reads it
        const int rest = mod / D[i];
        const long long term = 1ll * rest * ksm(rest, phi[i] - 1, D[i]) % mod * ans[i] % mod;
        total = add(total, static_cast<int>(term), mod);
    }
    return total;
}
int main()
{
    // Input: rows n, columns m, value range k, modulus mod.
    n = read();
    m = read();
    k = read();
    mod = read();
    // Split mod into prime powers, solve modulo each, then recombine with CRT.
    dvd(mod);
    for (int idx = 1; idx <= num; ++idx) solve(idx);
    const int answer = CRT();
    printf("%d", answer);
    return 0;
}
另一份简洁一些的模板(BZOJ_3129_方程)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
// 64-bit alias (not referenced in the code shown; 1ll literals are used instead).
typedef long long LL;
// Reads the next integer from standard input.
// Characters before the first digit are skipped; a '-' among them makes the
// result negative. If end of input is reached before any digit, returns 0
// instead of spinning forever on EOF.
inline int read()
{
    int c = getchar();   // int, not char, so EOF stays distinguishable
    int num = 0, f = 1;
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        num = num * 10 + c - '0';
        c = getchar();
    }
    return num * f;
}
// Global state for the extended-Lucas + inclusion-exclusion solution:
//   T, P         : number of test cases and the modulus
//   bin[i]       : 2^i (filled for i <= 10), used to enumerate subsets of the n1 bounded variables
//   n, n1, n2, m : per-test input (m is shifted in main before use)
//   a[i]         : values subtracted from m in the inclusion-exclusion (presumably upper bounds)
//   p[0]         : number of distinct prime factors of P; p[1..p[0]] are the primes
//   powp[i]      : full power of p[i] dividing P
//   jc[j]        : product of 1..j skipping multiples of p[cur], mod powp[cur];
//                  main writes jc[powp[cur]], so prime powers up to 10204 fit
//   ans[i]       : per-test answer modulo powp[i]
//   anss         : per-test answer modulo P after CRT
int T,P,bin[15],n,n1,n2,m,a[15],p[7],powp[7],jc[10205],ans[7],anss;
// Binary exponentiation: returns a^b mod p.
// For b == 0 the result is 1 without reduction (only differs when p == 1).
int ksm(int a,int b,int p)
{
    long long result = 1;
    long long base = a;
    while (b)
    {
        if (b & 1) result = result * base % p;
        base = base * base % p;
        b >>= 1;
    }
    return static_cast<int>(result);
}
// Extended Euclidean algorithm: sets x, y so that a*x + b*y == gcd(a, b)
// and returns gcd(a, b). The original version declared an int return but
// never returned a value, which is undefined behaviour in C++.
int exgcd(int a,int b,int &x,int &y)
{
    if (!b)
    {
        x=1; y=0;
        return a;
    }
    const int g=exgcd(b,a%b,y,x);
    y-=(a/b)*x;
    return g;
}
// Modular inverse of a modulo b, normalised into [0, b); returns 0 for a == 0.
// Computed with the iterative extended Euclidean algorithm, tracking only the
// coefficient of a. Assumes gcd(a, b) == 1; otherwise the result is just the
// Bezout coefficient of a, not an inverse.
int inv(int a,int b)
{
    if (a == 0) return 0;
    int r0 = a, r1 = b;
    int s0 = 1, s1 = 0;
    while (r1 != 0)
    {
        const int q = r0 / r1;
        const int r2 = r0 - q * r1;
        r0 = r1;
        r1 = r2;
        const int s2 = s0 - q * s1;
        s0 = s1;
        s1 = s2;
    }
    return (s0 % b + b) % b;
}
// (n!)_pi modulo pk: the product of all factors of n! coprime to pi, mod pk.
// Each level contributes jc[pk]^(n/pk) * jc[n % pk]; the multiples of pi,
// divided by pi, form the next level (n -> n / pi).
// Relies on the global jc[] having been filled for this pk.
int fac(int n,int pi,int pk)
{
    int result = 1;
    while (n != 0)
    {
        int level = (n / pk) ? ksm(jc[pk], n / pk, pk) : 1;
        level = 1ll * level * jc[n % pk] % pk;
        result = 1ll * result * level % pk;
        n /= pi;
    }
    return result;
}
// C(n, m) modulo the prime power pk = pi^e (one extended-Lucas step).
// Each factorial is split into its pi-free part (fac) and a power of pi; the
// net exponent of pi comes from Legendre's formula. Returns 0 when n < m.
int comb(int n,int m,int pi,int pk)
{
    if (n < m) return 0;
    int e = 0;   // exponent of pi in n! / (m! (n-m)!)
    for (int v = n / pi; v; v /= pi) e += v;
    for (int v = m / pi; v; v /= pi) e -= v;
    for (int v = (n - m) / pi; v; v /= pi) e -= v;
    const int numer = fac(n, pi, pk);
    const int den1 = fac(m, pi, pk);
    const int den2 = fac(n - m, pi, pk);
    return 1ll * numer * inv(den1, pk) % pk * inv(den2, pk) % pk * ksm(pi, e, pk) % pk;
}
// Modular add / subtract under the global modulus P, for operands in [0, P).
// Computed in 64 bits so they stay correct for any P up to INT_MAX.
inline int add(int a,int b) { long long s=static_cast<long long>(a)+b; if (s>=P) s-=P; return static_cast<int>(s); }
inline int sub(int a,int b) { long long s=static_cast<long long>(a)-b; if (s<0) s+=P; return static_cast<int>(s); }
int main()
{
    bin[0]=1;
    for (int i=1;i<=10;i++) bin[i]=bin[i-1]<<1;
    T=read(); P=read();

    // Factor P into prime powers: p[1..p[0]] are the primes and powp[i] is the
    // full power of p[i] dividing P. 1ll*i*i avoids int overflow for large P.
    int tmp=P;
    for (int i=2;1ll*i*i<=tmp;i++)
        if (tmp%i==0)
        {
            p[++p[0]]=i; powp[p[0]]=1;
            while (tmp%i==0) powp[p[0]]*=i,tmp/=i;
        }
    // The remaining factor (> 1) is a prime. The counter is advanced in its
    // own statement so p[p[0]] and powp[p[0]] address the same slot: in a
    // chained assignment that mutates the index, the right-hand operand is
    // evaluated first (C++17) or unsequenced with it (earlier standards).
    if (tmp>1)
    {
        ++p[0];
        p[p[0]]=powp[p[0]]=tmp;
    }

    while (T--)
    {
        n=read(); n1=read(); n2=read();
        // Shift m down by n (one unit per variable, as for positive variables).
        m=read()-n; anss=0;
        for (int i=1;i<=n1;i++) a[i]=read();
        // Each of the next n2 values b lowers the target by b-1
        // (presumably lower bounds x_i >= b on positive variables).
        for (int i=1;i<=n2;i++) m-=read()-1;
        if (m<0) { printf("0\n"); continue; }
        for (int cur=1;cur<=p[0];cur++)
        {
            // jc[j]: product of 1..j skipping multiples of p[cur], mod powp[cur].
            jc[0]=1; ans[cur]=0;
            for (int j=1;j<=powp[cur];j++)
            {
                jc[j]=jc[j-1];
                if (j%p[cur]) jc[j]=1ll*jc[j]*j%powp[cur];
            }
            // Inclusion-exclusion over subsets of the first n1 variables:
            // each chosen variable consumes a[i] extra units; odd subsets subtract.
            for (int sta=0;sta<bin[n1];sta++)
            {
                int cnt=0,ret; tmp=m;
                for (int i=1;i<=n1;i++)
                    if (sta&bin[i-1]) cnt++,tmp-=a[i];
                if (tmp<0) continue;
                // Stars and bars: non-negative solutions of y_1+...+y_n = tmp.
                ret=comb(tmp+n-1,tmp,p[cur],powp[cur]);
                if (cnt&1) ans[cur]=sub(ans[cur],ret);
                else ans[cur]=add(ans[cur],ret);
            }
        }
        // CRT: combine the residues modulo each powp[cur] into one modulo P.
        for (int cur=1;cur<=p[0];cur++)
        {
            int ret=1ll*ans[cur]*(P/powp[cur])%P*inv(P/powp[cur],powp[cur])%P;
            anss=add(anss,ret);
        }
        printf("%d\n",anss);
    }
    return 0;
}