题面
给出
n
,
p
,
k
,
r
(
1
≤
n
≤
1
0
9
,
0
≤
r
<
k
≤
50
,
2
≤
p
≤
2
30
−
1
)
n,p,k,r~~(1\leq n\leq10^9,0\leq r<k\leq 50,2\leq p\leq 2^{30}-1)
n,p,k,r (1≤n≤109,0≤r<k≤50,2≤p≤230−1),求
∑
i
=
0
∞
C
n
k
i
k
+
r
m
o
d
p
\sum_{i=0}^{\infty}C_{nk}^{ik+r} \mod p
i=0∑∞Cnkik+rmodp
题解
可以利用生成函数+二项式定理:
∑
i
=
0
∞
C
n
k
i
k
+
r
=
∑
i
=
0
∞
[
[
i
k
+
r
]
]
(
x
+
1
)
n
k
=
∑
i
=
0
∞
[
[
i
]
]
(
x
+
1
)
n
k
⋅
[
i
%
k
=
=
r
]
\sum_{i=0}^{\infty}C_{nk}^{ik+r} =\sum_{i=0}^{\infty}[[ik+r]](x+1)^{nk}\\ = \sum_{i=0}^{\infty}[[i]](x+1)^{nk}\cdot[i\%k==r]
i=0∑∞Cnkik+r=i=0∑∞[[ik+r]](x+1)nk=i=0∑∞[[i]](x+1)nk⋅[i%k==r]
这么转化用意已经很明显了,我们把它看成是一个环状多项式。把卷积定义为 C ( i + j ) m o d k : = ∑ A i ⋅ B j C_{(i+j)\!\!\!\mod k}:=\sum A_{i}\cdot B_{j} C(i+j)modk:=∑Ai⋅Bj ,那么多项式就只有最多 k k k 项,暴力卷积+快速幂就行。
也可以利用递推公式,这个是最好想的:
C
′
(
n
,
m
)
=
∑
i
C
(
n
,
i
k
+
m
)
,
C
′
(
i
,
j
)
+
C
′
(
i
,
(
j
+
1
)
m
o
d
k
)
→
C
′
(
i
+
1
,
(
j
+
1
)
m
o
d
k
)
C'(n,m)=\sum_{i} C(n,ik+m),\\ C'(i,j)+C'(i,(j+1)\!\!\!\!\mod k)\rightarrow C'(i+1,(j+1)\!\!\!\!\mod k)
C′(n,m)=i∑C(n,ik+m),C′(i,j)+C′(i,(j+1)modk)→C′(i+1,(j+1)modk)
还是把状态合并的思路,所有下标模 k k k 同余的项放到一起,递推式子仍然成立,那么就可以用矩阵加速。这个矩阵很特别,可以用多项式优化。
CODE
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 405
#define ENDL putchar('\n')
#define LL long long
#define DB double
#define lowbit(x) ((-x) & (x))
LL read() {
LL f = 1,x = 0;char s = getchar();
while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
return f * x;
}
void putuint(int x) {
if(!x) return ;
putuint(x/10);putchar(x%10+'0');
}
void putint(int x) {if(x==0)putchar('0');if(x<0)putchar('-'),x=-x;putuint(x);}
int n,m,i,j,s,o,k;
int MOD = 1;
struct poly{
int n;
int s[50];
poly(){n = 0;}
void set(int N) {
n = N;for(int i = 0;i < n;i ++) s[i] = 0;
}
};
poly operator * (poly a,poly b) {
poly c; c.set(a.n);
for(int i = 0;i < a.n;i ++) {
for(int j = 0;j < b.n;j ++) {
(c.s[(i+j)%c.n] += a.s[i]*1ll*b.s[j] % MOD) %= MOD;
}
}return c;
}
poly qkpow(poly a,LL b) {
poly res; res.set(a.n);res.s[0] = 1;
while(b > 0) {
if(b & 1) res = res * a;
a = a * a; b >>= 1;
}return res;
}
int main() {
LL N = read();MOD = read();
k = read(); int r = read();
poly A,B;B.set(k);
B.s[0] = 1;B.s[1%k] += 1;
A = qkpow(B,N*k);
printf("%d\n",A.s[r]);
return 0;
}