因为写错了一个变量名,debug了半个小时,又犯这样的错误。以此为戒。。。
题意:给你n个R、m个D,用这些字母组成一个字符串,再给你两个用R、D组成的单词,问你有多少种字符串包含这两个单词,单词可以重叠。
解法:先AC自动机,每个单词节点的val用状压表示,分别是1和2(二进制的01和10)。再用适配路径求出ok[maxnode],表示在每个节点所处的状态。之后就是DP了, dp[r][d][ac.sz][s]中,r是R的个数,d是D的个数,ac.sz是单词节点,s是状态。关于R的状态转移方程:dp[r+1][d][ch[ac.sz][0]][s|ok[ch[ac.sz][0]]]。同理,D的状态转移方程: dp[r][d+1][ch[ac.sz][1]][s|ok[ch[ac.sz][1]]]。
代码:
#include <iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<map>
#include<queue>
#include<cmath>
#include<vector>
#define inf 0x3f3f3f3f
#define Inf 0x3FFFFFFFFFFFFFFFLL
#define pi acos(-1.0)
#define eps 1e-8
using namespace std;
const int mod = 1000000007;
const int maxnode = 202;
const int char_size = 2;
struct autoAC
{
int ch[maxnode][char_size], val[maxnode], f[maxnode], last[maxnode];
int sz;
int idx(char c) {if(c=='R') return 0; return 1;}
void init() {memset(ch[0], 0,sizeof ch[0]);sz=1;}
void insert(char* s, int v = 1)
{
int u = 0, n = strlen(s);
for(int i = 0 ; i < n ; ++ i)
{
int c = idx(s[i]);
if(!ch[u][c])
{
memset(ch[sz],0,sizeof ch[sz]);
val[sz]=0;
ch[u][c] = sz++;
}
u = ch[u][c];
}
val[u]=v;
}
void getFail()
{
queue<int> q;
f[0] = 0;
for(int c = 0 ; c < char_size ; ++ c)
{
int u = ch[0][c];
if(u) {f[u]=last[u]=0; q.push(u);}
}
while(!q.empty())
{
int r = q.front(); q.pop();
for(int c = 0 ; c < char_size ; ++ c)
{
int u = ch[r][c];
if(!u) {ch[r][c]=ch[f[r]][c]; continue;}
q.push(u);
int v = f[r];
while(v&&!ch[v][c]) v = f[v];
f[u] = ch[v][c];
last[u] = val[f[u]]?f[u]:last[f[u]];
}
}
}
}ac;
int dp[110][110][maxnode][5];
int main()
{
//freopen("in.txt","r",stdin);
int t;scanf("%d", &t);
while(t--)
{
int r, d;scanf("%d%d", &r, &d);
char s[maxnode];ac.init();
for(int i = 0 ; i < 2 ; ++ i)
{
scanf("%s",s);
ac.insert(s,i+1);
}
ac.getFail();
int ok[maxnode] = {0};
for(int i = 0 ; i < ac.sz ; ++ i)
{
int j = i;
while(j)
{
ok[i]|=ac.val[j];
j = ac.last[j];
}
}
memset(dp, 0, sizeof dp);
dp[0][0][0][0] = 1;
for(int i = 0 ; i <= r ; ++ i)
for(int j = 0 ; j <= d ; ++ j)
for(int k = 0 ; k < ac.sz ; ++ k)
for(int s = 0 ; s < 4 ; ++ s)
{
if(!dp[i][j][k][s]) continue;
if(i<r)
{
int u = ac.ch[k][0]; int sb = ok[u]|s;
dp[i+1][j][u][sb] += dp[i][j][k][s];
if(dp[i+1][j][u][sb]>=mod) dp[i+1][j][u][sb]-=mod;
}
if(j<d)
{
int u = ac.ch[k][1]; int sb = ok[u]|s;
dp[i][j+1][u][sb] += dp[i][j][k][s];
if(dp[i][j+1][u][sb]>=mod) dp[i][j+1][u][sb]-=mod;
}
}
int ans = 0;
for(int i = 0 ; i < ac.sz ; ++ i)
{
ans += dp[r][d][i][3];
ans %= mod;
}
printf("%d\n",ans);
}
return 0;
}