题意:从0到9都用4位的二进制数表示,叫做BCD码,给出n个模式串,是不能出现的串,给出整数A和整数B,A <= B,问从A到B所有数字化为化成BCD码,问多少个数字的BCD码中没有任何模式串。
题解:0 <= A <= B <= 2^200,位数有200位,可以用数位dp来处理,先用自动机把所有模式串都存到trie图中,然后预处理出bcd[i][j]表示节点i添加的数字是j将会跳到的节点,在数位dp的时候可以直接判断是否可以添加数字i。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#define ll long long
using namespace std;
const int N = 2050;
const int MOD = 1000000009;
int Next[N][2], val[N], fail[N], sz, n, arr[205], bcd[N][10];
ll f[205][N];
char str[N], A[N], B[N];
void init() {
memset(Next[0], 0, sizeof(Next[0]));
val[0] = 0;
sz = 1;
}
void insert(char *s) {
int u = 0, len = strlen(s);
for (int i = 0; i < len; i++) {
int k = s[i] - '0';
if (!Next[u][k]) {
memset(Next[sz], 0, sizeof(Next[sz]));
val[sz] = 0;
Next[u][k] = sz++;
}
u = Next[u][k];
}
val[u] = 1;
}
void getFail() {
queue<int> Q;
fail[0] = 0;
for (int i = 0; i < 2; i++)
if (Next[0][i]) {
fail[Next[0][i]] = 0;
Q.push(Next[0][i]);
}
while (!Q.empty()) {
int u = Q.front();
Q.pop();
if (val[fail[u]])
val[u] = 1;
for (int i = 0; i < 2; i++) {
if (!Next[u][i])
Next[u][i] = Next[fail[u]][i];
else {
fail[Next[u][i]] = Next[fail[u]][i];
Q.push(Next[u][i]);
}
}
}
}
int solve2(int node, int a) {
if (val[node]) return -1;
int u = node;
for (int i = 3; i >= 0; i--) {
if (val[Next[u][1 & (a >> i)]]) return -1;
u = Next[u][1 & (a >> i)];
}
return u;
}
void init2() {
for (int i = 0; i < sz; i++)
for (int j = 0; j < 10; j++)
bcd[i][j] = solve2(i, j);
}
ll dp(int pos, int pre, bool limit, bool lead) {
if (pos == -1)
return 1;
if (!limit && f[pos][pre] != -1)
return f[pos][pre];
ll res = 0;
int end = limit ? arr[pos] : 9;
for (int i = 0; i <= end; i++) {
if (i == 0 && lead)
res = (res + dp(pos - 1, pre, limit && i == end, 1)) % MOD;
else if (bcd[pre][i] != -1)
res = (res + dp(pos - 1, bcd[pre][i], limit && i == end, 0)) % MOD;
}
if (!limit && !lead)
f[pos][pre] = res;
return res;
}
ll solve(char *s) {
int len = strlen(s);
for (int i = 0; i < len; i++)
arr[i] = s[len - 1 - i] - '0';
return dp(len - 1, 0, 1, 1);
}
int main() {
int t;
scanf("%d", &t);
while (t--) {
init();
memset(f, -1, sizeof(f));
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%s", str);
insert(str);
}
getFail();
init2();
scanf("%s%s", A, B);
int len = strlen(A);
for (int i = len - 1; i >= 0; i--) {
if (A[i] > '0') {
A[i]--;
break;
}
else A[i] = '9';
}
ll res2 = solve(A);
ll res1 = solve(B);
printf("%lld\n", (res1 - res2 + MOD) % MOD);
}
return 0;
}