题意
求有多少长度等于 N ( N ≤ 400 ) N(N\leq 400) N(N≤400)的 ∑ = { 0 , 1 } \sum = \{0,1\} ∑={0,1}字符串里面不包含长度大于等于 k ( k ≤ 10 ) k(k \leq 10) k(k≤10)的回文子串?
解题思路
这题一开始没啥想法,但是看到
k
k
k比较小想到了可以用状态压缩,不过不确定10个状态能不能保证不存在回文子串。看了题解以后发现最多11个状态就可以保证没有回文子串了。
结论:所有大于等于
k
k
k的回文子串一定包含一个长度为小于
k
k
k的回文子串。所以我们只需要判断当前串的前面
k
k
k和
k
+
1
k + 1
k+1(偶数)个子串不构成回文串就行了。
证明:当
k
k
k为奇数的时候,要凑成长度为
k
k
k的回文串需要一个前缀为
k
k
k的奇数个数回文串,或者一个长度为
k
−
1
k-1
k−1的偶数回文串两边各匹配一个相同的值来凑成,这时候长度就是
k
+
1
k+1
k+1。当
k
k
k为偶数的时候,那么跟奇数的情况刚好相反,所以得证。
进行状压DP的时候可以先预处理所有长度为
k
k
k的状态是否是回文串,这样能减少DP的时候检测回文串的时间。设DP函数为
d
p
i
,
j
dp_{i,j}
dpi,j表示在第
i
i
i个位置且前面状态为
j
j
j时候能组成的合法字符串数量。那么答案为
∑
j
∈
∣
s
∣
d
p
n
,
j
\sum\limits_{j\in|s|}dp_{n, j}
j∈∣s∣∑dpn,j。
时间复杂度
O ( N ∗ 2 k + 1 ) O(N*2^{k+1}) O(N∗2k+1)
代码
#include <algorithm>
#include <bitset>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
using namespace std;
typedef long long ll;
const int INF = 2147483647;
const int INF2 = 0x3f3f3f3f;
const ll INF64 = 1e18;
const double INFD = 1e30;
const double EPS = 1e-6;
const double PI = 3.1415926;
const ll MOD = 1e9 + 7;
int n, m, k;
int CASE;
const int MAXN = 1005;
int dp[405][1 << 11];
int valid[13][1 << 11];
inline int geti(int x, int i) { return (x >> i) & 1; }
inline bool check(int x, int z) {
for (int i = 0; i < z / 2; i++) {
if (geti(x, i) != geti(x, z - i - 1)) return false;
}
return true;
}
void init() {
for (int j = 1; j <= 12; j++) {
for (int i = 0; i < (1 << j); i++) {
valid[j][i] = check(i, j);
}
}
}
int main() {
#ifdef LOCALLL
freopen("in", "r", stdin);
freopen("out", "w", stdout);
#endif
init();
int T;
scanf("%d", &T);
while (T--) {
scanf("%d %d", &n, &k);
if (k == 1) {
printf("0\n");
continue;
}
memset(dp, 0, sizeof(dp));
int ans = 0;
dp[0][0] = 1;
for (int i = 1; i <= n; i++) {
for (int j = 0; j < (1 << (min(i, k))); j++) {
if (!dp[i - 1][j]) continue;
for (int x = 0; x < 2; x++) {
// 取得最后k位状态
int nexxt = ((j << 1) | x) & ((1 << k) - 1);
if (i >= k && valid[k][nexxt]) continue;
if (i > k && valid[k + 1][(j << 1) | x]) continue;
dp[i][nexxt] += dp[i - 1][j];
dp[i][nexxt] %= MOD;
}
}
}
for (int i = 0; i < (1 << k); i++) {
ans += dp[n][i];
ans %= MOD;
}
printf("%d\n", ans);
}
return 0;
}