链接
https://www.luogu.org/problem/show?pid=3301
组合数取模
有必要在这里插入对组合数取模的介绍。
欲求
Cmn mod p
如果p是比较小的素数,直接lucas定理求
ll C(ll n, ll m, ll p)
{
if(m>n)return 0;
return fact[n]*inv(fact[n-m],p)%p*inv(fact[m],p)%p;
}
ll lucas(ll n, ll m, ll p)
{
if(m==0)return 1;
return lucas(n/p,m/p,p)*C(n%p,m%p,p)%p;
}
如果p是合数但能够分解成单个素数的乘积,分别求模各个素数因子意义下的
Cmn
然后
CRT
合并即可。
如果p是普通合数,但是
p=pq11pq22...pqkk
,且
pqii
比较小,这就是今天要重点讨论的问题。
首先肯定是分别求出然后再
CRT
合并,那么问题转成如何求
Cmn mod pt
考虑公式
Cmn=n!m!(n−m)!
在这里 pt 肯定是合数,直接求逆元的话不保证互质。但是我们知道任何一个数 x ,都能表示成
这就是解决这道题目的关键, e 的部分由于和
现在问题分成两部分,对于一个 n!=e×pk 分别怎么求 e 和
n! mod p=1×2×3×...×n (mod p)
我们把 p 的倍数提出来,
就是这样啦。
题解
第二种限制就是高中数学裸题….直接插板法,用
M
减去
第一种限制。
n1=8
告诉我们这个算法可能是阶乘级别或者指数级别的,因此容易想到容斥,算出
>
限制的,然后容斥算就好了。
哦对了,原题中每个测试点的模数是告诉你的,具体可以看程序中的特判。
组合数取膜部分上面已经说过了。
总的复杂度是
代码
//容斥+卢卡斯定理+中国剩余定理
#include <cstdio>
#include <algorithm>
#define ll long long
#define maxn 100000
using namespace std;
ll T, P, N, n1, n2, M, fact[maxn], ans, A[maxn], a[maxn], m[maxn], tot, mi[maxn], k,
pp;
ll pow(ll a, ll b, ll p)
{
ll ans, t;
for(t=a,ans=1;b;b>>=1,t=t*t%p)if(b&1)ans=ans*t%p;
return ans;
}
void exgcd(ll a, ll b, ll &x, ll &y)
{
if(!b){x=1,y=0;return;}
ll xx, yy;
exgcd(b,a%b,xx,yy);
x=yy, y=xx-a/b*yy;
}
ll inv(ll a, ll p)
{
ll x, y;
exgcd(a,p,x,y);
return (x+p)%p;
}
ll calcfact(ll n, ll p, ll &pt)
{
if(n==0)return 1;
ll t1, t2, i;
t1=pow(fact[p-1],n/p,p);
t1=t1*fact[n%p]%p;
pt+=n/pp;
t2=calcfact(n/pp,p,pt);
return t1*t2%p;
}
ll C(ll n, ll m, ll p)
{
ll A, B, C, ptA=0, ptB=0, ptC=0, pt;
A=calcfact(n,p,ptA);
B=calcfact(m,p,ptB);
C=calcfact(n-m,p,ptC);
pt=ptA-ptB-ptC;
return A*inv(B,p)%p*inv(C,p)%p*pow(pp,pt,p)%p;
}
void dfs(ll pos, ll sum, ll k, ll p)
{
if(M-sum<N)return;
if(pos>n1)
{
if(k&1)ans-=C(M-sum-1,N-1,p);
else ans+=C(M-sum-1,N-1,p);
return;
}
dfs(pos+1,sum,k,p);
dfs(pos+1,sum+A[pos],k+1,p);
}
ll calc(ll p)
{
ans=0;
dfs(1,0,0,p);
return ((ans%p)+p)%p;
}
ll gcd(ll a, ll b){return !b?a:gcd(b,a%b);}
ll crt(ll *a, ll *m, ll n)
{
ll M=1, ans=0, i;
for(i=1;i<=n;i++)M*=m[i];
for(i=1;i<=n;i++)ans=(ans+a[i]*(M/m[i])%M*inv(M/m[i],m[i]))%M;
return ans;
}
void work()
{
ll i, j, t, p;
scanf("%lld%lld%lld%lld",&N,&n1,&n2,&M);
for(i=1;i<=n1+n2;i++)scanf("%lld",A+i);
for(i=n1+1;i<=n1+n2;i++)M-=A[i]-1;
for(i=1;i<=tot;i++)
{
if(m[i]==125)pp=5,k=3;
else if(m[i]==343)pp=7,k=3;
else if(m[i]==10201)pp=101,k=2;
else pp=m[i],k=1;
p=m[i];
fact[0]=1;
for(j=1;j<p;j++)if(j%pp!=0)fact[j]=fact[j-1]*j%p;else fact[j]=fact[j-1];
a[i]=calc(p);
// printf(" a=%lld\n",a[i]);
}
printf("%lld\n",crt(a,m,tot));
}
int main()
{
ll T, P;
scanf("%lld%lld",&T,&P);
if(P==262203414){tot=5;m[1]=2;m[2]=3;m[3]=11;m[4]=397;m[5]=10007;}
if(P==437367875){tot=3;m[1]=125;m[2]=343;m[3]=10201;}
if(P==10007){tot=1;m[1]=10007;}
while(T--)work();
return 0;
}