貌似是洛谷的题,但是没找到
题目描述
题解
首先考虑怎么朴素地求函数 f m f_m fm,我们可以发现它就是把 1 , m , m 2 , . . . 1,m,m^2,... 1,m,m2,... 挨个拿进去做了一个完全背包。
若把 DP 数组看作一个生成函数,那么用某个数做一遍完全背包相当于乘上一个多项式。所以求 g m k g_m^k gmk 只需要把每个数重复做 k k k 次完全背包即可。
复杂度 O ( n k log m n ) O(nk\log_mn) O(nklogmn),这样就可以拿 43 分了。接下来我准备往生成函数方面想,但我是GF的屎,我失败了。
于是不妨直接看看上面的暴力怎么优化:
观察 DP 过程,以 k = 1 k=1 k=1 为例,先用1做完全背包,相当于把只有0处为1其它全0的 DP 数组做了个前缀和。此时 DP 值关于 n n n 是零次的,前缀和是一次的;
后面每次做完全背包的数都是
m
m
m 的倍数,所以每
m
m
m 个连续的 DP 值的变化是类似的。不妨设
d
p
′
[
i
]
=
∑
j
=
i
m
(
i
+
1
)
m
−
1
d
p
[
j
]
dp'[i]=\sum_{j=im}^{(i+1)m-1}dp[j]
dp′[i]=∑j=im(i+1)m−1dp[j],那么当我们用
m
m
m 做完全背包时,相当于把
d
p
′
dp'
dp′ 数组做了个前缀和!
由于我们最后要求的是
n
n
n 处的前缀和,所以不妨改一个定义:
d
p
′
[
i
]
=
∑
j
=
i
m
+
(
n
%
m
)
(
i
+
1
)
m
+
(
n
%
m
)
−
1
d
p
[
j
]
dp'[i]=\sum_{j=im+(n\%m)}^{(i+1)m+(n\%m)-1}dp[j]
dp′[i]=j=im+(n%m)∑(i+1)m+(n%m)−1dp[j]这样就能保证最后可以通过
d
p
′
dp'
dp′ 求到
n
n
n 处的前缀和。
再看这个
d
p
′
dp'
dp′ 数组,做完第二遍完全背包时,它关于
n
n
n 是一次的,前缀和是二次的。
后面的过程类似。
我们发现由于每次做背包的数呈倍数递增,当我们用 x x x 做完完全背包后,若把 DP 数组连续 x x x 个值分段加起来打包成新数组,那么这个新数组是一个次数较小的函数。
总共会做不超过 k log m n k\log_mn klogmn 次背包,函数的次数每次+1,最大也是 O ( k log m n ) O(k\log_mn) O(klogmn) 数量级。所以我们可以只维护前 k log m n k\log_mn klogmn 个 DP 值,当打包新数组的时候需要求另外 k log m n k\log_mn klogmn 个地方的点值(当然是前缀和的点值),直接 O ( k log m n ) O(k\log_mn) O(klogmn) 拉插即可。
总复杂度 O ( k 2 log m 3 n ) O(k^2\log_m^3n) O(k2logm3n),我是直接跑满了。
代码
真的非常简单
#include<bits/stdc++.h>//JZM yyds!!
#define ll long long
#define lll __int128
#define uns unsigned
#define fi first
#define se second
#define IF (it->fi)
#define IS (it->se)
#define END putchar('\n')
#define lowbit(x) ((x)&-(x))
#define inline jzmyyds
using namespace std;
const int MAXN=23333;
const ll INF=1e17;
ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+(s^48),s=getchar();
return f?x:-x;
}
int ptf[50],lpt;
void print(ll x,char c='\n'){
if(x<0)putchar('-'),x=-x;
ptf[lpt=1]=x%10;
while(x>9)x/=10,ptf[++lpt]=x%10;
while(lpt>0)putchar(ptf[lpt--]^48);
if(c>0)putchar(c);
}
const ll MOD=1e9+7;
ll ksm(ll a,ll b,ll mo){
ll res=1;
for(;b;b>>=1,a=a*a%mo)if(b&1)res=res*a%mo;
return res;
}
ll fac[MAXN],inv[MAXN];
int init(int n){
fac[0]=inv[0]=1;
for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%MOD;
inv[n]=ksm(fac[n],MOD-2,MOD);
for(int i=n-1;i>0;i--)inv[i]=inv[i+1]*(i+1)%MOD;
return 114514;
}
const int cbddl=init(23333-233),m=1300;
ll n,x,f[MAXN],cg[MAXN],g[MAXN];
int k;
ll Lag(ll x){
ll pf=1,s=0;
for(int i=0;i<=m;i++){
ll c1=f[i]*cg[i]%MOD,c2=(x-i+MOD)%MOD;
s=(s*c2+pf*c1)%MOD,(pf*=c2)%=MOD;
}return s;
}
int main()
{
freopen("partition.in","r",stdin);
freopen("partition.out","w",stdout);
n=read(),x=read(),k=read();
for(int i=0;i<=m;i++){
cg[i]=inv[i]*inv[m-i]%MOD,f[i]=1;
if((m^i)&1)cg[i]=MOD-cg[i];
}
while(n){
for(int D=k;D--;)for(int i=1;i<=m;i++){
f[i]+=f[i-1];
if(f[i]>=MOD)f[i]-=MOD;
}
ll a=n%x;n/=x;
for(int i=0;i<=m;i++,a+=x)g[i]=Lag(a);
for(int i=0;i<=m;i++)f[i]=g[i];
}
print(f[0]);
return 0;
}