【题目链接】
【思路要点】
- 问题要求所有数互不相同,不妨规定\(A_1<A_2<\cdots<A_N\)。
- 按照数位DP的思路,从高位向低位DP,记录一个\(N\)位的二进制状态:第\(i\ (1\le i<N)\)位表示在已处理的高位中\(A_i\)是否仍等于\(A_{i+1}\),第\(N\)位表示在已处理的高位中\(A_N\)是否仍等于上界\(R\)(即是否仍“贴着”上界)。
- 转移时枚举当前位上所有数的取值;为使这一位的异或和为\(0\),只需枚举含偶数个\(1\)的取值。
- 上界当前位为\(0\)和为\(1\)时的转移分别可以看作一个转移矩阵。由于上界由同一个串\(S\)重复\(K\)次构成,可以先求出一个周期的矩阵乘积,再用矩阵快速幂处理重复部分。
- 时间复杂度\(O(2^{3N}\cdot(\log K+|S|))\)。
【代码】
// Counts strictly increasing N-tuples A_1 < A_2 < ... < A_N of non-negative
// integers whose bitwise XOR is 0 and whose members are all below R, where R is
// the binary number obtained by writing the bit string S exactly K times.
// Input: N, K, then S.  Output: the count modulo 1e9+7.
//
// Digit DP from the most significant bit downwards. A state is an N-bit mask:
//   bit j (0 <= j < N-1): A_{j+1} and A_{j+2} are still equal on the prefix;
//   bit N-1             : A_N is still equal to the bound's prefix ("tight").
// Each bound digit (0 or 1) induces a fixed 2^N x 2^N transfer matrix, so one
// copy of S is a matrix product and the repeated copies are handled by fast
// exponentiation. The bound actually walked is R - 1, treated as inclusive.
#include <algorithm>
#include <bitset>
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

namespace {

const int P = 1000000007;

// Reads a (possibly negative) integer from stdin; stops cleanly at EOF
// instead of spinning forever when input is truncated.
template <typename T> void read(T &x) {
    x = 0;
    int f = 1;
    int c = std::getchar();  // int, not char: keeps EOF distinct and isdigit() defined
    for (; c != EOF && !std::isdigit(c); c = std::getchar())
        if (c == '-') f = -f;
    for (; c != EOF && std::isdigit(c); c = std::getchar()) x = x * 10 + (c - '0');
    x *= f;
}

// Reads the next whitespace-delimited token from stdin (no fixed buffer size).
std::string readToken() {
    std::string tok;
    int c = std::getchar();
    while (c != EOF && std::isspace(c)) c = std::getchar();
    while (c != EOF && !std::isspace(c)) {
        tok.push_back(static_cast<char>(c));
        c = std::getchar();
    }
    return tok;
}

// Dense square matrix over Z_P, stored row-major.
struct Matrix {
    int dim;
    std::vector<int> cell;

    explicit Matrix(int d) : dim(d), cell(static_cast<std::size_t>(d) * d, 0) {}

    int &at(int i, int j) { return cell[static_cast<std::size_t>(i) * dim + j]; }
    int at(int i, int j) const { return cell[static_cast<std::size_t>(i) * dim + j]; }

    static Matrix identity(int d) {
        Matrix m(d);
        for (int i = 0; i < d; ++i) m.at(i, i) = 1;
        return m;
    }
};

// Returns x * y (mod P). Loop order i-mid-j keeps the inner loop on a row of y.
Matrix multiply(const Matrix &x, const Matrix &y) {
    const int d = x.dim;
    Matrix z(d);
    std::vector<long long> row(d);
    for (int i = 0; i < d; ++i) {
        std::fill(row.begin(), row.end(), 0LL);
        for (int mid = 0; mid < d; ++mid) {
            const long long xim = x.at(i, mid);
            if (xim == 0) continue;
            for (int j = 0; j < d; ++j) row[j] = (row[j] + xim * y.at(mid, j)) % P;
        }
        for (int j = 0; j < d; ++j) z.at(i, j) = static_cast<int>(row[j]);
    }
    return z;
}

// Returns the row vector v * m (mod P).
std::vector<int> applyMatrix(const std::vector<int> &v, const Matrix &m) {
    std::vector<long long> acc(m.dim, 0);
    for (int i = 0; i < m.dim; ++i) {
        if (v[i] == 0) continue;
        for (int j = 0; j < m.dim; ++j) acc[j] = (acc[j] + 1LL * v[i] * m.at(i, j)) % P;
    }
    return std::vector<int>(acc.begin(), acc.end());
}

// Builds the transfer matrix for one bit position whose bound digit is
// `boundDigit` (0 or 1). Entry [from][to] counts the digit assignments t
// (bit j of t = this digit of A_{j+1}) leading from state `from` to `to`.
Matrix buildTransfer(int n, int boundDigit) {
    const int states = 1 << n;
    const int tightBit = 1 << (n - 1);
    Matrix m(states);
    for (int from = 0; from < states; ++from) {
        for (int t = 0; t < states; ++t) {
            // Only an even number of 1s keeps this bit of the XOR equal to 0.
            if (std::bitset<32>(static_cast<unsigned long long>(t)).count() % 2 != 0) continue;
            bool ok = true;
            int to = from;
            for (int j = 0; j + 1 < n; ++j) {
                if (!((from >> j) & 1)) continue;  // A_{j+1} < A_{j+2} already decided
                const int lo = (t >> j) & 1;
                const int hi = (t >> (j + 1)) & 1;
                if (lo > hi) { ok = false; break; }  // would make A_{j+1} > A_{j+2}
                if (lo < hi) to &= ~(1 << j);        // strict order is now decided
            }
            if (!ok) continue;
            if (from & tightBit) {
                const int top = (t >> (n - 1)) & 1;
                if (top > boundDigit) continue;             // A_N would exceed the bound
                if (top < boundDigit) to &= ~tightBit;      // A_N drops strictly below
            }
            ++m.at(from, to);
        }
    }
    return m;
}

}  // namespace

int main() {
    int n = 0, k = 0;
    read(n), read(k);
    const std::string s = readToken();

    // Degenerate input: with no '1' in S (or no copies at all) the bound R is 0,
    // so no non-negative value lies below it.
    if (n < 1 || k < 1 || s.find('1') == std::string::npos) {
        std::printf("0\n");
        return 0;
    }

    const int states = 1 << n;
    const Matrix one = buildTransfer(n, 1);
    const Matrix zero = buildTransfer(n, 0);

    // Transfer matrix for one full copy of S, applied left to right.
    Matrix block = Matrix::identity(states);
    for (char c : s) block = multiply(block, c == '1' ? one : zero);

    // Start: every number equal (empty prefix) and tight against the bound.
    std::vector<int> vec(states, 0);
    vec[states - 1] = 1;

    // First K-1 copies of S via square-and-multiply; works for any int K
    // (powers of one matrix commute, so the application order is irrelevant).
    for (int e = k - 1; e > 0; e >>= 1) {
        if (e & 1) vec = applyMatrix(vec, block);
        if (e > 1) block = multiply(block, block);
    }

    // Last copy is S - 1 so that the walked bound is R - 1 (inclusive).
    // S contains a '1' (checked above), so the borrow always terminates.
    std::string last = s;
    for (auto it = last.rbegin(); it != last.rend(); ++it) {
        if (*it == '1') {
            *it = '0';
            break;
        }
        *it = '1';
    }
    for (char c : last) vec = applyMatrix(vec, c == '0' ? zero : one);

    // Accept states where every adjacent pair is strictly ordered; the tight
    // flag may be either value because the bound R - 1 is inclusive.
    const int tightBit = 1 << (n - 1);
    const long long ans = (static_cast<long long>(vec[0]) + vec[tightBit]) % P;
    std::printf("%lld\n", ans);
    return 0;
}