【题目链接】
【思路要点】
- 将每个质因数是否在两个集合中出现过状压,容易得到一个\(O(3^{Cntprime})\)的状态表示,其中\(Cntprime\)为\(N\)以内质数的个数,但这显然无法通过。
- 我们发现任何正整数\(X\),其大于\(\sqrt{X}\)的质因数至多只有一个。
- 我们对小于\(\sqrt{N}\)的质数进行状压,将存在大于等于\(\sqrt{N}\)的质因数的\(X\)按照其质因数分类,其余的数各自为一类,需要满足同类的数至多出现在一个集合里。
- 分组后DP即可解决,时间复杂度\(O(N*Cnt*3^{Cnt})\),其中\(Cnt\)为小于\(\sqrt{N}\)的质数个数,当\(N=500\)时,\(Cnt=8\)。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 505; const int MAXS = 10005; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } struct info {int type, mask; }; int n, P, m; int tot, prime[MAXN], bit[MAXN]; int dp[2][3][MAXS]; void update(int &x, int y) {x = (x + y) % P; } int main() { read(n), read(P); if (n <= 3) { if (n == 2) printf("%d\n", 3 % P); else printf("%d\n", 9 % P); return 0; } for (int i = 2; i <= n; i++) { bool isprime = true; for (int j = 2; j * j <= i; j++) if (i % j == 0) { isprime = false; break; } if (isprime) { prime[++tot] = i; if (i * i <= n) m = tot; } } bit[1] = 1; for (int i = 2; i <= m; i++) bit[i] = bit[i - 1] * 3; int goal = bit[m] * 3 - 1; dp[0][0][0] = 1; for (int i = 2; i <= n; i++) { memset(dp[1], 0, sizeof(dp[1])); bool found = false; for (int j = m + 1; j <= tot; j++) if (i % prime[j] == 0) found = true; if (found) continue; for (int s = 0; s <= goal; s++) { int t1 = s; bool valid1 = true; int t2 = s; bool valid2 = true; for (int j = 1; j <= m; j++) if (i % prime[j] == 0) { if (s / bit[j] % 3 == 1) valid2 = false; if (s / bit[j] % 3 == 2) valid1 = false; if (s / bit[j] % 3 == 0) t1 += bit[j], t2 += bit[j] * 2; } update(dp[1][0][s], dp[0][0][s]); if (valid1) update(dp[1][0][t1], dp[0][0][s]); if (valid2) update(dp[1][0][t2], dp[0][0][s]); } memcpy(dp[0], dp[1], sizeof(dp[1])); } for (int i = m + 1; i <= tot; i++) { for (int val = prime[i]; val <= n; val += prime[i]) { memset(dp[1], 0, sizeof(dp[1])); for (int s = 0; s <= goal; s++) { int t1 = s; bool valid1 = true; int t2 = s; bool valid2 = true; for (int j = 1; j <= m; j++) if (val % prime[j] == 0) { if (s / bit[j] % 3 == 1) valid2 = false; if (s / bit[j] % 3 == 2) valid1 = false; if (s / bit[j] % 3 == 0) t1 += bit[j], t2 += bit[j] * 2; } update(dp[1][0][s], dp[0][0][s]); update(dp[1][1][s], dp[0][1][s]); update(dp[1][2][s], dp[0][2][s]); if (valid1) { update(dp[1][1][t1], dp[0][0][s]); update(dp[1][1][t1], dp[0][1][s]); } if (valid2) { update(dp[1][2][t2], dp[0][0][s]); update(dp[1][2][t2], dp[0][2][s]); } } memcpy(dp[0], dp[1], sizeof(dp[1])); } memset(dp[1], 0, sizeof(dp[1])); for (int j = 0; j <= goal; j++) { update(dp[1][0][j], dp[0][0][j]); update(dp[1][0][j], dp[0][1][j]); update(dp[1][0][j], dp[0][2][j]); } memcpy(dp[0], dp[1], sizeof(dp[1])); } int ans = 0; for (int i = 0; i <= 2; i++) for (int j = 0; j <= goal; j++) update(ans, dp[0][i][j]); writeln(ans); return 0; }