题目
Luogu
题目大意:
k
k
k 种珍珠,每种珍珠都要用上,问能做出长度[1,2,…,N]的首饰的方案数,答案模
1234567891
1234567891
1234567891
T
<
=
10
,
1
<
=
N
<
=
1000000000
,
0
<
=
K
<
=
30
T <=10, 1<= N<= 1000000000, 0<= K<= 30
T<=10,1<=N<=1000000000,0<=K<=30
思路
我们定义:
f
[
i
]
[
j
]
:
前
i
个
位
置
用
j
种
珍
珠
的
方
案
数
f[i][j]:前i个位置用j种珍珠的方案数
f[i][j]:前i个位置用j种珍珠的方案数
于是有转移方程式:
f
[
i
]
[
j
]
=
f
[
i
−
1
]
[
j
]
∗
j
+
f
[
i
−
1
]
[
j
−
1
]
∗
(
k
−
j
)
f[i][j]=f[i-1][j]*j+f[i-1][j-1]*(k-j)
f[i][j]=f[i−1][j]∗j+f[i−1][j−1]∗(k−j)
其中
f
[
0
]
[
0
]
=
1
f[0][0]=1
f[0][0]=1
那么
A
n
s
=
∑
i
=
1
n
f
[
i
]
[
k
]
Ans=\sum_{i=1}^{n}f[i][k]
Ans=∑i=1nf[i][k]
发现这样每次
i
i
i 只会+1,并且转移时系数和
i
i
i 没太大关系,于是可以考虑矩阵加速
我们记
a
n
s
i
=
∑
j
=
1
i
f
[
j
]
[
k
]
ans_i=\sum_{j=1}^{i}f[j][k]
ansi=∑j=1if[j][k]
那么可以得到转换矩阵:
(
f
[
i
−
1
]
[
0
]
f
[
i
−
1
]
[
1
]
⋯
f
[
i
−
1
]
[
k
]
a
n
s
i
−
2
)
∗
(
0
k
0
⋯
0
0
0
0
1
k
−
1
⋯
0
0
0
0
0
2
⋯
0
0
0
⋮
⋮
⋮
⋱
⋮
⋮
⋮
0
0
0
⋯
k
−
1
1
0
0
0
0
⋯
0
k
1
0
0
0
⋯
0
0
1
)
=
(
f
[
i
]
[
0
]
f
[
i
]
[
1
]
⋯
f
[
i
]
[
k
]
a
n
s
i
−
1
)
\left( \begin{matrix} f[i-1][0]&f[i-1][1]&\cdots&f[i-1][k]&ans_{i-2} \end{matrix} \right) * \left( \begin{matrix} 0&k&0&\cdots&0&0&0\\ 0&1&k-1&\cdots&0&0&0\\ 0&0&2&\cdots&0&0&0\\ \vdots&\vdots&\vdots&\ddots&\vdots&\vdots&\vdots\\ 0&0&0&\cdots&k-1&1&0\\ 0&0&0&\cdots&0&k&1\\ 0&0&0&\cdots&0&0&1\\ \end{matrix} \right) \quad= \left( \begin{matrix} f[i][0]&f[i][1]&\cdots&f[i][k]&ans_{i-1} \end{matrix} \right)
(f[i−1][0]f[i−1][1]⋯f[i−1][k]ansi−2)∗⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛000⋮000k10⋮0000k−12⋮000⋯⋯⋯⋱⋯⋯⋯000⋮k−100000⋮1k0000⋮011⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞=(f[i][0]f[i][1]⋯f[i][k]ansi−1)
那么记转换矩阵为
A
A
A ,那么
A
n
+
1
A^{n+1}
An+1 的右上角即为答案
时间复杂度:
O
(
k
3
l
o
g
n
)
O(k^3log_n)
O(k3logn)
代码
#include<set>
#include<map>
#include<stack>
#include<cmath>
#include<cstdio>
#include<queue>
#include<vector>
#include<climits>
#include<cstring>
#include<iostream>
#include<algorithm>
#define LL long long
using namespace std;
LL read(){
LL f=1,x=0;char s=getchar();
while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();}
return x*f;
}
#define MAXN 32
#define INF 0x3f3f3f3f
#define Mod 1234567891
struct Matrix{
int r,c;
LL m[MAXN+5][MAXN+5];
Matrix(){}
Matrix(int R,int C){r=R,c=C,memset(m,0,sizeof(m));}
Matrix operator * (Matrix a){
Matrix b(r,a.c);
for(int i=1;i<=r;i++)
for(int j=1;j<=a.c;j++)
for(int k=1;k<=c;k++)
b.m[i][j]=(b.m[i][j]+m[i][k]*a.m[k][j]%Mod)%Mod;
return b;
}
void print(){
for(int i=1;i<=r;i++)
for(int j=1;j<=c;j++)
printf("%lld",m[i][j]),putchar(j==c?'\n':' ');
return ;
}
};
Matrix Pow(Matrix x,int y){
Matrix ret(x.r,x.c);
for(int i=1;i<=x.r;i++)
ret.m[i][i]=1;
while(y){
if(y&1) ret=ret*x;
x=x*x;
y>>=1;
}
return ret;
}
Matrix A,B;
int main(){
int T=read();
while(T--){
int n=read(),k=read();
A=B=Matrix();
A.r=A.c=k+2;
for(int i=2;i<=k+1;i++)
A.m[i][i]=i-1,A.m[i-1][i]=k-i+2;
A.m[k+1][k+2]=1,A.m[k+2][k+2]=1;
A=Pow(A,n+1);//ans[-1]...ans[n]
B.r=1,B.c=k+2;
B.m[1][1]=1;
B=B*A;
printf("%lld\n",B.m[1][k+2]);
}
return 0;
}