这题口胡起来真简单。。。
先来考虑
∑mi=1xi≤n,xi≥0
的整数
{xi}
个数
这可以用插板法证明是
(n+mm)
对于原问题,一个经典做法是容斥,每次枚举一些数一定超过上界,其他数任意。所以答案可以表示为:
∑S∈U(−1)|S|(n+c∗|S|−|S|−∑x∈sbx+mm)
(这里要满足
n+c∗|S|−|S|−∑x∈sbx>0
)
可以发现,
(n+mm)
是一个关于
n
的
所以可以先枚举
|S|
,然后就能将
n+c∗|S|−|S|
算出来。预处理一个
dp(i,j,k)
表示1~i位选了j个他们的k次方的和是多少。然后考虑枚举第一个不同的数位是哪个,然后算一下。。。
复杂度
O(m4)
#include<bits/stdc++.h>
#define maxn 60
#define mod 998244353
using namespace std;
int C[maxn][maxn],ans,dp[maxn][maxn][maxn],m,b,c;
int A[maxn],im,B[maxn],n,Q[maxn],S,len,R[maxn];
char s[maxn*maxn];
int qpow(int a,int b){
int ans=1,tmp=a;
for(;b;b>>=1,tmp=1ll*tmp*tmp%mod)
if(b&1)ans=1ll*ans*tmp%mod;
return ans;
}
int getdiv(int b){
int sum=0;
for(int i=n;i>=1;--i){
long long x=10ll*sum+s[i];
s[i]=x/b,sum=x%b;
}
while(n&&!s[n])n--;
return sum;
}
bool add(int c){
int ret=c;
for(int i=1;i<=len&&ret;++i)
if(Q[i]+ret<0)Q[i]=b+Q[i]+ret,ret=-1;
else if(Q[i]+ret>=b)Q[i]=Q[i]+ret-b,ret=1;
else Q[i]=Q[i]+ret,ret=0;
while(len&&!Q[len])len--;
if(ret<0)return false;
return true;
}
void print(int ans){
static int cas=0;
printf("Case #%d: %d\n",++cas,ans);
}
void cal(int i,int j,int nz){
if(i<0)return ;
for(int k=0;k<=m;++k)
for(int nk=0,l=1;nk<=k&&l;++nk,l=1ll*l*nz%mod)if(dp[j][i][k-nk])
R[k]=(R[k]+1ll*dp[j][i][k-nk]*l%mod*C[k][nk]%mod*(k-nk&1?mod-1:1))%mod;
// printf("[%d,%d:%d]\n",nz,k,nk);
}
int main(){
while(scanf("%d%d%d",&m,&b,&c)==3){
scanf("%s",s+1),n=strlen(s+1),S=len=0;
for(int i=1;i<=n;++i)s[i]-='0';
reverse(s+1,s+n+1);
if(n==1&&s[1]==0){
print(0);
continue;
}
for(int i=1,ret=1;i<=n&&ret;++i)
if(s[i]-ret<0)s[i]=10+s[i]-ret,ret=1;
else s[i]-=ret,ret=0;
while(n&&!s[n])n--;
while(n)Q[++len]=getdiv(b);
for(int i=0;i<=m;++i)
for(int j=*C[i]=1;j<=i;++j)
C[i][j]=(C[i-1][j-1]+C[i-1][j])%mod;
im=1,B[0]=1;
for(int i=1;i<=m;++i)B[i]=1ll*B[i-1]*b%mod;
for(int i=1;i<=m;++i)im=1ll*im*i%mod;
im=qpow(im,mod-2);
memset(A,0,sizeof(A)),A[0]=1;
for(int i=0;i<=m-1;++i)
for(int j=i;j>=0;--j)
A[j+1]=(A[j+1]+A[j])%mod,A[j]=1ll*(m-i)*A[j]%mod;
for(int i=1,j=1;i<=len;++i,j=1ll*j*b%mod)S=(S+1ll*j*Q[i])%mod;
// for(int i=0;i<=m;++i)printf("[%d]",A[i]);puts("");
for(int i=0;i<=m;++i)A[i]=1ll*A[i]*im%mod;
memset(dp,0,sizeof(dp));
dp[0][0][0]=1;
for(int i=1;i<=len&&i<=m;++i)
for(int j=0;j<=i;++j)
for(int k=0;k<=m;++k){
dp[i][j][k]=dp[i-1][j][k];
if(j)for(int l=0,z=1;l<=k;++l,z=1ll*z*B[i]%mod)if(dp[i-1][j-1][k-l])
dp[i][j][k]=(dp[i][j][k]+1ll*dp[i-1][j-1][k-l]*C[k][l]%mod*z)%mod;
// if(dp[i][j][k])printf("dp[%d][%d][%d]=%d\n",i,j,k,dp[i][j][k]);
}
int ans=0;
for(int i=0,dx=1;i<=m;++i,dx=mod-dx){
int sum=0;
if(!len){
ans=(ans+1ll*C[m][i]*dx)%mod;
if(!add(c-1))break;
continue;
}
for(int j=0;j<=m;++j)R[j]=0;
int x=0,z=0;
// printf("{%d}",len);
for(int j=len;j>=2&&x<=i;--j)if(Q[j]){
// printf("call:[%d,%d,%d]",i-x,j-2,S-z);
cal(i-x,j-2,(S-z+mod)%mod);
x++,z=(z+B[j-1])%mod;
if(j-1>m)goto nxt;
if(Q[j]>1){
cal(i-x,j-2,(S-z+mod)%mod);
goto nxt;
}
}
if(i==x){
for(int j=0,l=1;j<=m&&l;j++,l=1ll*l*(S-z+mod)%mod)
R[j]=(R[j]+l)%mod;
}
nxt:;
// printf("\n[i=%d]\n",i);
// for(int j=0;j<=m;++j)printf("[%d]",R[j]);
for(int j=0;j<=m;++j)sum=(sum+1ll*R[j]*A[j])%mod;
ans=(ans+1ll*sum*dx)%mod;
if(!add(c-1))break;
S=(1ll*S+c-1+mod)%mod;
}
print(ans);
}
}