题面
题解
看到 m m m 很大,联想到矩阵快速幂。
由于对于每个初始的 x x x,它变成 1 1 1 的方法是唯一的。所以我们可以考虑从 1 1 1 倒推,这样用不同的方法倒推得到的数肯定是不同的,所以不会算重。
为了方便,我们从 0 0 0 而不是 1 1 1 开始倒推,此时原来倒推 m m m 步就变成了倒推 m + 1 m+1 m+1 步(这样会算多,但我们最后再处理多出来的这一部分)。
假设当前数为 x x x,考虑 x x x 向前倒推:若 x ≢ k − 1 ( m o d k ) x\not \equiv k-1 \pmod k x≡k−1(modk),则 x x x 可以由 x + 1 x+1 x+1 和 x k xk xk 转移得到;若 x ≡ k − 1 ( m o d k ) x\equiv k-1\pmod k x≡k−1(modk),则 x x x 只可以由 x k xk xk 转移得到。
那么设
f
i
,
j
f_{i,j}
fi,j 表示经过倒推
i
i
i 步后,有多少个数是模
k
k
k 余
j
j
j 的(即有多少个模
k
k
k 余
j
j
j 的数经过
i
i
i 步操作后变成
1
1
1),容易得到初始矩阵
G
G
G 和转移矩阵
A
A
A:(矩阵的行列从
0
0
0 开始编号)
G
=
f
0
=
[
1
0
0
⋮
0
]
A
=
[
1
1
⋯
1
1
1
0
⋯
0
0
0
1
⋯
0
0
⋮
⋮
⋱
⋮
⋮
0
0
⋯
1
0
]
k
×
k
G=f_0=\begin{bmatrix} 1\\ 0\\ 0\\ \vdots\\ 0 \end{bmatrix}\\ A=\begin{bmatrix} 1&1&\cdots&1&1\\ 1&0&\cdots&0&0\\ 0&1&\cdots&0&0\\ \vdots&\vdots&\ddots&\vdots&\vdots\\ 0&0&\cdots&1&0 \end{bmatrix}_{k\times k}\\
G=f0=⎣⎢⎢⎢⎢⎢⎡100⋮0⎦⎥⎥⎥⎥⎥⎤A=⎣⎢⎢⎢⎢⎢⎡110⋮0101⋮0⋯⋯⋯⋱⋯100⋮1100⋮0⎦⎥⎥⎥⎥⎥⎤k×k
那么我们要求的是
G
A
m
+
1
GA^{m+1}
GAm+1 的第
0
0
0 列所有数的和(即
∑
i
=
0
k
−
1
f
m
+
1
,
i
\sum\limits_{i=0}^{k-1}f_{m+1,i}
i=0∑k−1fm+1,i),即
G
A
m
+
2
GA^{m+2}
GAm+2 的第
0
0
0 列第
0
0
0 行的数。
我们先得到 A A A 的特征多项式 f ( λ ) = ∣ λ I − A ∣ = λ k − ∑ i = 0 k − 1 λ i = λ k − λ k − 1 − λ k − 2 − ⋯ − λ 0 f(\lambda)=|\lambda I-A|=\lambda^k-\sum\limits_{i=0}^{k-1}\lambda^i=\lambda^k-\lambda^{k-1}-\lambda^{k-2}-\cdots-\lambda^0 f(λ)=∣λI−A∣=λk−i=0∑k−1λi=λk−λk−1−λk−2−⋯−λ0。
由某个 C 开头的定理知 f ( A ) = 0 f(A)=0 f(A)=0。
由于我们要求 A m + 2 A^{m+2} Am+2,所以我们不妨将 A m + 2 A^{m+2} Am+2 一直减去 f ( A ) f(A) f(A) 直到次数恰好小于 k k k 为止(也就是取模)。
不妨令 g ( λ ) = λ m + 2 m o d f ( λ ) g(\lambda)=\lambda^{m+2}\bmod f(\lambda) g(λ)=λm+2modf(λ)(这个可以用边快速幂边取模的方法求出。注意 f ( λ ) f(\lambda) f(λ) 的形式很特殊,可以 O ( k ) O(k) O(k) 简单取模)。将 g ( λ ) g(\lambda) g(λ) 展开,设 g ( λ ) = ∑ i = 0 k − 1 a i λ i g(\lambda)=\sum\limits_{i=0}^{k-1}a_i\lambda^i g(λ)=i=0∑k−1aiλi。
那么 A m + 2 = A m + 2 m o d f ( A ) = g ( A ) = ∑ i = 0 k − 1 a i A i A^{m+2}=A^{m+2}\bmod f(A)=g(A)=\sum\limits_{i=0}^{k-1}a_iA^i Am+2=Am+2modf(A)=g(A)=i=0∑k−1aiAi。
那么我们要求的就是:
( G A m + 2 ) 0 , 0 = ∑ i = 0 k − 1 a i ( G A i ) 0 , 0 \left(GA^{m+2}\right)_{0,0}=\sum_{i=0}^{k-1}a_i\left(GA^{i}\right)_{0,0} (GAm+2)0,0=i=0∑k−1ai(GAi)0,0
注意到 ( G A i ) 0 , 0 \left(GA^i\right)_{0,0} (GAi)0,0 就是 f i , 0 f_{i,0} fi,0,所以原式即为: ( G A m + 2 ) 0 , 0 = ∑ i = 0 k − 1 a i f i , 0 \left(GA^{m+2}\right)_{0,0}=\sum\limits_{i=0}^{k-1}a_if_{i,0} (GAm+2)0,0=i=0∑k−1aifi,0。
所以我们只需要知道 0 ≤ i < k 0\leq i<k 0≤i<k 的 f i , 0 f_{i,0} fi,0 即可。
但是暴力预处理是 O ( k 2 ) O(k^2) O(k2) 的,还是太大了。
但容易发现在 0 ≤ i < k 0\leq i <k 0≤i<k 的情况下, f i , 0 = 2 i f_{i,0}=2^i fi,0=2i。
大概是因为在操作步数 < k <k <k 的情况下,不可能出现当前数模 k k k 余 k − 1 k-1 k−1 向外转移的情况,所以每个数都有两种向外的转移方法。
那么就能在 O ( k log k log m ) O(k\log k\log m) O(klogklogm) 的时间内求出 ( G A m + 2 ) 0 , 0 \left(GA^{m+2}\right)_{0,0} (GAm+2)0,0 了。
一开始我们说过,这样会算多,因为我们是从 0 0 0 开始倒推的。
所以我们需要让倒推的第一步强制选 + 1 +1 +1(即强制让正推的最后一步是 1 − 1 → 0 1-1\to 0 1−1→0)。(注意这里保证了 k > 1 k>1 k>1,如果 k = 1 k=1 k=1 的话请在程序开始时特判输出 1 1 1)
观察
f
1
f_{1}
f1 所对应的矩阵:
[
1
1
0
0
⋮
0
]
\begin{bmatrix} 1\\ 1\\ 0\\ 0\\ \vdots\\ 0 \end{bmatrix}
⎣⎢⎢⎢⎢⎢⎢⎢⎡1100⋮0⎦⎥⎥⎥⎥⎥⎥⎥⎤
理论上来说,我们需要强制让除
f
1
,
1
=
1
f_{1,1}=1
f1,1=1 之外的其他位置都是
0
0
0。
所以我们要减去 f 1 , 0 = 1 f_{1,0}=1 f1,0=1 时这个 1 1 1 对之后的矩乘的贡献。
也就是说我们要减去 G A m GA^m GAm 第 0 0 0 列所有数的和,即 ( G A m + 1 ) 0 , 0 \left(GA^{m+1}\right)_{0,0} (GAm+1)0,0。
所以真正的答案应该是 ( G A m + 2 ) 0 , 0 − ( G A m + 1 ) 0 , 0 \left(GA^{m+2}\right)_{0,0}-\left(GA^{m+1}\right)_{0,0} (GAm+2)0,0−(GAm+1)0,0。
总时间复杂度 O ( k log k log m ) O(k\log k\log m) O(klogklogm)。
关于取模的问题:
由于 m o d ≤ 300 mod\leq 300 mod≤300 很小,所以在快速幂中:多项式乘法时用 NTT 在模一个很大的模数 M M M(如 M = 998244353 M=998244353 M=998244353)意义下进行,目的是让 NTT 中的模不起作用(即乘出来的系数不可能大于等于 M M M,模了 M M M 和没模一样);多项式取模时再用 m o d mod mod 模,因为此时的多项式取模可以不用 NTT, O ( k ) O(k) O(k) 做。
代码如下:
#include<bits/stdc++.h>
#define LN 16
#define N 10010
#define ll long long
using namespace std;
const int M=998244353;
namespace modular
{
inline int add(const int x,const int y,const int mod=M){return x+y>=mod?x+y-mod:x+y;}
inline int dec(const int x,const int y,const int mod=M){return x-y<0?x-y+mod:x-y;}
inline int mul(const int x,const int y,const int mod=M){return 1ll*x*y%mod;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int poww(int a,int b,const int mod=M)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
int k,mod;
int limit,rev[N<<2],w[LN][N<<2][2];
int a[N<<2];
ll m;
void init(int limit)
{
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
int len=mid<<1;
int gn=poww(3,(M-1)/len);
int ign=poww(gn,M-2);
int g=1,ig=1;
for(int j=0;j<mid;g=mul(g,gn),ig=mul(ig,ign),j++)
w[bit][j][0]=g,w[bit][j][1]=ig;
}
}
void NTT(int *a,int limit,int opt)
{
opt=(opt<0);
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
for(int i=0;i<limit;i++)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
for(int i=0,len=mid<<1;i<limit;i+=len)
{
for(int j=0;j<mid;j++)
{
int x=a[i+j],y=mul(w[bit][j][opt],a[i+mid+j]);
a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
}
}
}
if(opt)
{
int tmp=poww(limit,M-2);
for(int i=0;i<limit;i++)
a[i]=mul(a[i],tmp);
}
}
void modmul(int *f,int *g)
{
static int A[N<<2],B[N<<2],sum[N<<3];
for(int i=0;i<limit;i++) A[i]=f[i],B[i]=g[i];
NTT(A,limit,1),NTT(B,limit,1);
for(int i=0;i<limit;i++) A[i]=mul(A[i],B[i]);
NTT(A,limit,-1);
for(int i=0;i<limit;i++) A[i]%=mod;
for(int i=limit-1;i>=k;i--)
sum[i]=add(sum[i+1],add(A[i],dec(sum[i+1],sum[i+k+1],mod),mod),mod);
for(int i=k-1;i>=0;i--)
f[i]=add(A[i],dec(sum[k],sum[i+k+1],mod),mod);
for(int i=0;i<limit;i++) A[i]=B[i]=sum[i]=0;
}
void work(int *ans,ll b)
{
static int now[N<<2];
ans[0]=1,now[1]=1;
while(b)
{
if(b&1ll) modmul(ans,now);
modmul(now,now);
b>>=1ll;
}
}
int main()
{
scanf("%d%lld%d",&k,&m,&mod);
if(k==1)
{
puts("1");
return 0;
}
limit=1;
while(limit<((k+1)<<1)) limit<<=1;
init(limit);
work(a,m+1);
int ans1=0;
for(int i=0,tmp=1;i<k;i++)
{
ans1=add(ans1,mul(a[i],tmp,mod),mod);
if(i) tmp=add(tmp,tmp,mod);
}
for(int i=k;i>=1;i--) a[i]=a[i-1];
a[0]=0;
for(int i=0;i<k;i++) a[i]=add(a[i],a[k],mod);
int ans2=0;
for(int i=0,tmp=1;i<k;i++)
{
ans2=add(ans2,mul(a[i],tmp,mod),mod);
if(i) tmp=add(tmp,tmp,mod);
}
printf("%d\n",dec(ans2,ans1,mod));
return 0;
}
/*
2 4 31
*/