题面
解法
将问题稍微转化一下就变得比较简单了
- 直接求似乎并没有那么好做,考虑补集转化,即最后的答案=总方案数-每一个数都不是质数的方案数
- 可以发现,总方案数和每一个数都不是质数的求法其实本质上是一样的,就暂且先只考虑每一个数都不是质数的情况怎么计算答案
- p ≤ 100 p≤100 p≤100,那么我们可以将所有非质数按照 % p \%p %p的余数分类,记 s [ i ] s[i] s[i]表示 % p = i \%p=i %p=i的数的个数
- 然后我们可以记状态 f [ i ] [ j ] f[i][j] f[i][j]表示已经填了 i i i个数,且当前所有数的和对 p p p取模为 j j j的方案数
- 转移方程就很容易写了, f [ i ] [ j ] = ∑ f [ i − 1 ] [ j ′ ] × s [ j − j ′ ] f[i][j]=\sum f[i-1][j']×s[j-j'] f[i][j]=∑f[i−1][j′]×s[j−j′]
- 然后对于每一个 j j j,能转移到它的 j ′ j' j′和 s [ j − j ′ ] s[j-j'] s[j−j′]是固定的,所以可以通过矩阵乘法来优化这个过程
- 同样,总方案也可以这么计算
- 时间复杂度: O ( p 3 log n ) O(p^3\log n) O(p3logn)
代码
#include <bits/stdc++.h>
#define LL long long
#define Mod 20170408
#define M 20000010
#define N 110
using namespace std;
template <typename node> void read(node &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
struct Matrix {
int a[N][N];
void Clear() {memset(a, 0, sizeof(a));}
};
int n, m, p, len, s1[N], s2[N], pr[M / 10];
bool f[M];
Matrix operator * (Matrix x, Matrix y) {
Matrix ret; ret.Clear();
for (int k = 0; k < p; k++)
for (int i = 0; i < p; i++)
for (int j = 0; j < p; j++)
ret.a[i][j] = ((LL)ret.a[i][j] + (LL)x.a[i][k] * y.a[k][j] % Mod) % Mod;
return ret;
}
Matrix Pow(Matrix x, int y) {
Matrix ret = x; y--;
while (y) {
if (y & 1) ret = ret * x;
y >>= 1; x = x * x;
}
return ret;
}
void sieve() {
len = 0;
for (int i = 2; i <= m; i++) f[i] = true;
for (int i = 2; i <= m; i++) {
if (f[i]) pr[++len] = i;
for (int j = 1; j <= len && i * pr[j] <= m; j++) {
f[i * pr[j]] = false;
if (i % pr[j] == 0) break;
}
}
}
int calc(int n, int key) {
Matrix tx; tx.Clear();
for (int i = 0; i < p; i++)
for (int j = 0; j < p; j++) {
int k = (i - j + p) % p;
tx.a[i][j] = (key == 1) ? s1[k] : s2[k];
}
tx = Pow(tx, n);
Matrix ret; ret.Clear();
ret.a[0][0] = 1; ret = tx * ret;
return ret.a[0][0];
}
main() {
read(n), read(m), read(p);
sieve();
for (int i = 1; i <= m; i++) {
s1[i % p] = (s1[i % p] + 1) % Mod;
if (!f[i]) s2[i % p] = (s2[i % p] + 1) % Mod;
}
cout << (calc(n, 1) - calc(n, 2) + Mod) % Mod << "\n";
return 0;
}