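AC自动机上的数位DP（digit DP over an Aho–Corasick automaton）
思路：先统计位数与给定数 N 相同且不超过 N、且不包含任何给定模式串的数的个数，再统计位数少于 N 的同类数的个数，两者之和即为答案。(Approach: first count the numbers that have the same number of digits as N, do not exceed N, and contain none of the given patterns; then count such numbers with fewer digits than N. The answer is the sum of the two counts.)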
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
// Capacity used for the trie node pool, the digit buffers and both DP
// dimensions. Parenthesised so the macro expands safely inside larger
// expressions (without the parentheses, 2 * N would expand to 2 * 2000 + 5).
#define N (2000+5)
// Every count in this program is reduced modulo this value.
#define MOD 1000000007
using namespace std;
// Whole-input buffer, filled by a single fread() in main(), and the shared
// parse cursor into it. One extra byte is reserved so that the buffer is
// always NUL-terminated, even if the input fills all 7000000 bytes that main()
// reads. The scanners below rely on that '\0' to stop.
char str[7000000 + 1], *ch = str;

// Parses the next decimal integer at the global cursor `ch` into x and
// leaves `ch` just past its last digit.
// - A '-' makes the value negative only when it immediately precedes the
//   digits; a stray '-' earlier in the separator run is ignored.
// - If the buffer is exhausted before any digit is found, x becomes 0 and
//   `ch` stops at the terminating '\0' instead of running off the buffer.
inline void read(int &x)
{
    int opt = 1;
    for (; *ch && (*ch < '0' || *ch > '9'); ++ch)
        opt = (*ch == '-') ? -1 : 1;
    for (x = 0; *ch >= '0' && *ch <= '9'; ++ch)
        x = x * 10 + (*ch - '0');
    x *= opt;
}
// One state of the Aho-Corasick automaton.
// After bfs() has run:
// - s[d] is the goto transition on digit d (only indices 0..9 are used).
// - danger is true when a pattern ends at this state or at any state on
//   its fail chain.
struct node
{
    node *fail;     // failure link
    node *s[26];    // trie children, later completed into goto transitions
    bool danger;    // reaching this state means a forbidden pattern occurred
};
node Trie[N];       // node pool; Trie[0] is unused, the root is Trie[1]
node *R;            // root of the automaton

int n;              // number of digits of the first input number (the limit)
int m;              // number of patterns still to be read
int tot;            // number of nodes taken from Trie so far
int l;              // length of the most recently scanned digit string
int f[N][N][2];     // DP table indexed as f[length][state][tight]
int i, j, k;        // loop counters shared by all functions
char s[N];          // digits of the limit, 1-indexed, as values 0..9
char c[N];          // characters of the pattern being inserted, 1-indexed
queue<node*> Q;     // work queue for bfs()
// Reads the next run of decimal digits at the global cursor `ch` into
// c[1..l] and inserts it into the trie rooted at R, taking new nodes from
// Trie[] (advancing tot) as needed. The node where the pattern ends is
// marked dangerous. Scanning stops at the buffer's terminating '\0'.
inline void add()
{
    // Skip separators, but never past the end of the input.
    for (; *ch && (*ch < '0' || *ch > '9'); ++ch);
    for (l = 0; *ch >= '0' && *ch <= '9'; ++ch)
        c[++l] = *ch;
    // Input exhausted: there is no pattern to insert.
    if (l == 0)
        return;
    node *p = R;
    for (i = 1; i <= l; ++i)
    {
        int x = c[i] - '0';
        if (!p->s[x]) p->s[x] = &Trie[++tot];
        p = p->s[x];
    }
    // Incrementing a bool is ill-formed since C++17; set the flag directly.
    p->danger = true;
}
void bfs()
{
for (i= 0;i <= 9; ++i)
if (R->s[i])
R->s[i]->fail = R, Q.push(R->s[i]);
else
R->s[i] = R;
while (!Q.empty())
{
node *x = Q.front();
Q.pop();
for (i = 0;i <= 9; ++i)
{
if (x->s[i])
{
x->s[i]->fail = x->fail->s[i];
x->s[i]->danger |= x->s[i]->fail->danger;
Q.push(x->s[i]);
}
else
x->s[i] = x->fail->s[i];
}
}
}
// Counts the integers in [1, N] whose decimal form contains none of the
// given digit patterns, and prints that count modulo MOD.
// Input: the digit string N, then the pattern count m, then the m patterns.
//
// DP state f[len][state][tight]: the number of digit strings of length len
// that start with a nonzero digit and drive the automaton into `state`
// without entering a dangerous state.
// - tight = 1: the string equals the first len digits of N.
// - tight = 0: the string is already smaller than that prefix.
int main()
{
    // Read everything at once. Reading one byte less than the buffer size
    // leaves a terminating '\0', which the scanners use to stop.
    fread(str, 1, sizeof(str) - 1, stdin);
    // Skip anything (e.g. leading whitespace) before N. Without this, a
    // non-digit first byte would make N an empty string (n == 0).
    while (*ch && (*ch < '0' || *ch > '9'))
        ++ch;
    // Store the digits of N, 1-indexed, as values 0..9.
    for (; *ch >= '0' && *ch <= '9'; ++ch)
        s[++l] = *ch - '0';
    n = l;
    read(m);
    // Trie[1] is the root; every digit gets a real depth-1 node up front.
    R = &Trie[++tot];
    for (i = 0; i <= 9; ++i)
        R->s[i] = &Trie[++tot];
    while (m--)
        add();
    bfs();

    // Part 1: numbers with exactly n digits that are <= N.
    // The first digit ranges over 1..s[1]; it is tight only when equal to s[1].
    for (i = 1; i <= s[1]; ++i)
        if (!R->s[i]->danger)
            f[1][R->s[i] - Trie][i == s[1]] = 1;
    for (i = 1; i <= n - 1; ++i)
        for (j = 1; j <= tot; ++j)
        {
            // From a tight prefix the next digit may not exceed s[i+1].
            // The string stays tight only when the digit equals s[i+1].
            for (k = 0; k <= s[i + 1]; ++k)
                if (Trie[j].s[k] ? !Trie[j].s[k]->danger : 0)
                    (f[i + 1][Trie[j].s[k] - Trie][k == s[i + 1]] += f[i][j][1]) %= MOD;
            // From a prefix that is already below N, any digit is allowed.
            for (k = 0; k <= 9; ++k)
                if (Trie[j].s[k] ? !Trie[j].s[k]->danger : 0)
                    (f[i + 1][Trie[j].s[k] - Trie][0] += f[i][j][0]) %= MOD;
        }
    int ans = 0;
    for (i = 1; i <= tot; ++i)
        (ans += f[n][i][0]) %= MOD, (ans += f[n][i][1]) %= MOD;

    // Part 2: numbers with 1..n-1 digits. These are always below N, so
    // only the tight = 0 layer is used.
    memset(f, 0, sizeof(f));
    for (i = 1; i <= 9; ++i)
        if (!R->s[i]->danger)
            f[1][R->s[i] - Trie][0] = 1;
    for (i = 1; i <= n - 2; ++i)
        for (j = 1; j <= tot; ++j)
            for (k = 0; k <= 9; ++k)
                if (Trie[j].s[k] ? !Trie[j].s[k]->danger : 0)
                    (f[i + 1][Trie[j].s[k] - Trie][0] += f[i][j][0]) %= MOD;
    for (i = 1; i <= n - 1; ++i)
        for (j = 1; j <= tot; ++j)
            (ans += f[i][j][0]) %= MOD;
    printf("%d\n", ans);
    return 0;
}