Sum
时间限制: 1 Sec 内存限制: 128 MB
题目描述
求有多少个 $n$ 位十进制数是 $p$ 的倍数且各位数字之和小于等于 $m_i$($m_i=0,1,2,\dots,m$),允许前导 0,答案对 998244353 取模。
输入
一行三个整数 $n,p,m$。($1 \le n \le 10^9,\ 1 \le p \le 50,\ 1 \le m \le 1000$)
输出
输出一行 $m+1$ 个正整数,分别表示 $m_i=0,1,2,\dots,m$ 时的答案。
样例输入
2 3 3
样例输出
1 1 1 5
分析:考虑数位 dp。设 $dp[i][j][x]$ 表示前 $i$ 位组成的数模 $p$ 余 $j$、数位和为 $x$ 的方案数,枚举第 $i$ 位的数字 $k$,有转移 $dp[i][(10j+k) \bmod p][x+k] \mathrel{+}= dp[i-1][j][x]$;
然后考虑倍增;
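这里把 $dp[x]$ 看作长度为 $x$ 的数位块的状态,则 $dp[2x][(j \cdot 10^x + l) \bmod p][a+b] \mathrel{+}= dp[x][j][a] \cdot dp[x][l][b]$(对所有 $j,l,a,b$ 求和);按 $n$ 的二进制位把对应的块用同样的方式合并进答案即可;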
第三维(数位和)的合并是一个卷积形式,且只需保留 $0 \sim m$ 项,因此可以用 NTT 优化到 $O(m \log m)$;
那么总复杂度是 $O(\log n \cdot p \cdot m \cdot (\log m + p))$:每次倍增要做 $O(p)$ 次长度为 $O(m)$ 的 NTT,以及 $O(p^2 m)$ 次点值相乘;
代码:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <climits>
#include <cstring>
#include <string>
#include <set>
#include <bitset>
#include <map>
#include <queue>
#include <stack>
#include <vector>
#include <cassert>
#include <ctime>

// Solution for "Sum": count n-digit decimal strings (leading zeros allowed) that are
// multiples of p, bucketed by digit sum, and print the prefix sums for bounds 0..m.
//
// A "table" has p rows; the coefficient of x^s in row r counts digit strings whose value
// is congruent to r (mod p) and whose digit sum is s, truncated to degrees 0..m.
// Tables for a high block H and a low block L (L has len_L digits) combine as
//     out[(i * 10^len_L + j) mod p] += H[i] * L[j]        (polynomial product),
// which is evaluated row-by-row in the NTT domain. Binary exponentiation over n builds
// the n-digit table from the one-digit table with O(log n) such combinations.

namespace {

const int kMod = 998244353;  // 119 * 2^23 + 1: admits NTT lengths up to 2^23
const int kRoot = 3;         // primitive root modulo kMod

typedef std::vector<int> Poly;
typedef std::vector<Poly> Table;

// Returns base^exp mod kMod for exp >= 0.
long long pow_mod(long long base, long long exp) {
    long long result = 1;
    base %= kMod;
    while (exp > 0) {
        if (exp & 1) result = result * base % kMod;
        base = base * base % kMod;
        exp >>= 1;
    }
    return result;
}

// Number-theoretic transform for one fixed power-of-two length. The bit-reversal
// permutation, the powers of the root of unity and 1/len are computed once with integer
// arithmetic, so each call only performs the butterflies.
class Ntt {
public:
    explicit Ntt(int len) : len_(len), rev_(len, 0), roots_(len + 1, 0), inv_len_(0) {
        int bits = 0;
        while ((1 << bits) < len_) ++bits;
        for (int i = 1; i < len_; ++i)
            rev_[i] = (rev_[i >> 1] >> 1) | ((i & 1) << (bits - 1));
        const long long w = pow_mod(kRoot, (kMod - 1) / len_);  // primitive len-th root of unity
        roots_[0] = 1;
        for (int i = 1; i <= len_; ++i)
            roots_[i] = static_cast<int>(roots_[i - 1] * w % kMod);
        inv_len_ = pow_mod(len_, kMod - 2);
    }

    // Transforms a (which must hold exactly len coefficients in [0, kMod)) in place.
    // With inverse == true this is the inverse transform, including the 1/len scaling.
    void transform(Poly& a, bool inverse) const {
        for (int i = 0; i < len_; ++i)
            if (i < rev_[i]) std::swap(a[i], a[rev_[i]]);
        for (int size = 2; size <= len_; size <<= 1) {
            const int half = size >> 1;
            const int step = len_ / size;
            for (int start = 0; start < len_; start += size) {
                for (int j = 0; j < half; ++j) {
                    // roots_[len - t] is the inverse of roots_[t], since roots_[len] == 1.
                    const long long w = roots_[inverse ? len_ - j * step : j * step];
                    const long long u = a[start + j];
                    const long long v = a[start + j + half] * w % kMod;
                    a[start + j] = static_cast<int>((u + v) % kMod);
                    a[start + j + half] = static_cast<int>((u - v + kMod) % kMod);
                }
            }
        }
        if (inverse)
            for (int i = 0; i < len_; ++i)
                a[i] = static_cast<int>(a[i] * inv_len_ % kMod);
    }

private:
    int len_;
    std::vector<int> rev_;
    std::vector<int> roots_;
    long long inv_len_;
};

// Copies src into dst and forward-transforms every row of dst.
void to_spectrum(const Ntt& ntt, const Table& src, Table& dst) {
    dst = src;
    for (size_t r = 0; r < dst.size(); ++r) ntt.transform(dst[r], false);
}

// Inverse-transforms every row of t and zeroes coefficients above max_sum. Digit sums
// above m never contribute to any answer, and the zeroed tail keeps the next product
// (degree <= 2m < len) free of cyclic wrap-around.
void to_coefficients(const Ntt& ntt, Table& t, int max_sum) {
    for (size_t r = 0; r < t.size(); ++r) {
        ntt.transform(t[r], true);
        std::fill(t[r].begin() + max_sum + 1, t[r].end(), 0);
    }
}

// In the NTT domain, sets out[(i * shift + j) % p] = sum over i, j of a[i] .* b[j],
// where p = a.size() and shift = 10^(digit count of b's block) mod p. a supplies the
// high digits and b the low digits. out must already be p rows of the same length.
void combine_spectra(const Table& a, const Table& b, int shift, Table& out) {
    const int p = static_cast<int>(a.size());
    const int len = static_cast<int>(a[0].size());
    for (int r = 0; r < p; ++r) std::fill(out[r].begin(), out[r].end(), 0);
    for (int i = 0; i < p; ++i) {
        const Poly& x = a[i];
        for (int j = 0; j < p; ++j) {
            const Poly& y = b[j];
            Poly& z = out[(i * shift + j) % p];
            for (int k = 0; k < len; ++k)
                z[k] = static_cast<int>((z[k] + static_cast<long long>(x[k]) * y[k]) % kMod);
        }
    }
}

// Returns exact[s] for s = 0..m: the number (mod kMod) of n-digit strings, leading zeros
// allowed, that are divisible by p and have digit sum exactly s.
Poly count_by_digit_sum(int n, int p, int m) {
    int len = 1;
    while (len < 2 * (m + 1)) len <<= 1;  // holds a product of two degree-<=m polynomials
    const Ntt ntt(len);

    Table result(p, Poly(len, 0));  // digits combined so far; starts as the empty string
    result[0][0] = 1;
    Table block(p, Poly(len, 0));   // a block of 2^k digits; starts with a single digit
    for (int d = std::min(9, m); d >= 0; --d) block[d % p][d]++;  // digits > m are useless
    int shift = 10 % p;             // 10^(current block length) mod p

    Table block_hat;
    Table result_hat;
    Table product(p, Poly(len, 0));
    for (; n > 0; n >>= 1) {
        to_spectrum(ntt, block, block_hat);  // transformed once, shared by both products
        if (n & 1) {
            to_spectrum(ntt, result, result_hat);
            combine_spectra(result_hat, block_hat, shift, product);
            to_coefficients(ntt, product, m);
            result.swap(product);
        }
        if (n > 1) {  // the doubled block is needed only while higher bits of n remain
            combine_spectra(block_hat, block_hat, shift, product);
            to_coefficients(ntt, product, m);
            block.swap(product);
            shift = shift * shift % p;
        }
    }
    result[0].resize(m + 1);
    return result[0];
}

}  // namespace

// Reads "n p m" and prints m + 1 answers: for each bound 0..m, the number of n-digit
// multiples of p whose digit sum is at most that bound, modulo 998244353.
int main() {
    int n = 0;
    int p = 0;
    int m = 0;
    if (std::scanf("%d%d%d", &n, &p, &m) != 3) return 0;
    const Poly exact = count_by_digit_sum(n, p, m);
    long long at_most = 0;  // running prefix sum: digit sum <= bound
    for (int bound = 0; bound <= m; ++bound) {
        at_most = (at_most + exact[bound]) % kMod;
        std::printf("%lld%c", at_most, bound == m ? '\n' : ' ');
    }
    return 0;
}