Description
Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望
,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。
Input
一行三个数,n,m,p。
1<=n<=10^9,1<=m<=2×10^7,1<=p<=100
Output
一行一个数,满足Alice的要求的序列数量,答案对20170408取模。
Sample Input
3 5 3
Sample Output
33
暴力DP:
要求总和为p的倍数,容易想到将p的剩余系作为一维状态,f[i][j]就表示长度为i的序列,总和mod p为j 的 方案数
枚举每次选的数,进行转移即可
而对于第二个条件,根据容斥原理,先算出一般情况,再减去没有质数的情况即可
但这样只能拿到20分左右
优化:
先对于枚举每次选的数进行优化,把每个数放到c数组c[i]表示1-m中模p为i的数的个数
转移变为
f[i+1][(j+k) % p] = f[i][j] * c[k]
这里考虑矩阵乘法加速,每次转移构造成一个初等矩阵s
s[i][j] 为 c[(j-i+p)%p]
向量f[i] 经过 s 变换后成为向量f[i+1]
大量的转移用矩阵快速幂先算出
冗长代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#define File(x) "count."#x
#define For(i,s,e) for(int i=(s); i<=(e); i++)
#define Rep(i,s,e) for(int i=(s); i>=(e); i--)
using namespace std;
const int N=500+3,M=2*10000000,Mod=20170408;
struct Matrix{
int n,m,a[N][N];
void clear(){
n=m=0; memset(a,0,sizeof a);
}
void init(int x, int y){
this->clear();
n=x; m=y;
For(i,0,n-1) a[i][i]=1;
}
Matrix operator * (const Matrix &o) const {
Matrix ret; ret.clear();
ret.n=n, ret.m=o.m;
For(i,0,ret.n-1) For(j,0,ret.m-1) For(k,0,m-1){
ret.a[i][j]+=(1LL*a[i][k]*o.a[k][j])%Mod;
ret.a[i][j]%=Mod;
}
return ret;
}
}f,s;
int n,m,p,pr[M/3],c[N],ans1,ans2;
bool notp[M];
void getPrime()
{
For(i,2,m){
if(!notp[i]) pr[++pr[0]]=i;
for(int j=1; j<=pr[0] && (long long)(pr[j]*i)<=m; j++){
int tmp=pr[j]*i;
notp[tmp]=1;
if(!(i%pr[j])) break;
}
}
}
void matrixPow(int b)
{
Matrix ret; ret.init(p,p);
while(b){
if(b&1) ret=ret*s;
s=s*s;
b>>=1;
}
f=f*ret;
}
int main()
{
freopen(File(in),"r",stdin);
freopen(File(out),"w",stdout);
// ios::sync_with_stdio(false);
cin>>n>>m>>p;
For(i,0,p-1) c[i]=(m/p);
For(i,1,m%p) c[i]++;
f.clear();
f.n=1; f.m=p;
For(i,0,p-1) f.a[0][i]=c[i];
s.clear();
s.n=s.m=p;
For(i,0,s.n-1) For(j,0,s.n-1) s.a[i][j]=c[(j-i+p)%p];
matrixPow(n-1);
ans1=f.a[0][0];//一般情况
getPrime();
For(i,1,pr[0]) c[pr[i]%p]--;
f.clear();
f.n=1; f.m=p;
For(i,0,p-1) f.a[0][i]=c[i];
s.clear();
s.n=s.m=p;
For(i,0,s.n-1) For(j,0,s.n-1) s.a[i][j]=c[(j-i+p)%p];
matrixPow(n-1); ans2=f.a[0][0];//没有质数的情况
ans1=ans1-ans2;
if(ans1<0) ans1+=Mod;
else if(ans1>Mod) ans1-=Mod;
printf("%d\n",ans1);
return 0;
}