题目链接:http://hihocoder.com/problemset/problem/1877
因为近似匹配最多只能有1个字符不匹配,所以向ac自动机里加入n+1个模式串(还有一个原串)。然后用AC自动机DP,insert之后构建fail边。
如果要让dpij代表前i个字符到达第j个节点并且至少近似匹配一次的话,转移起来会非常麻烦,所以就反过来统计前i个字符到达第j个节点没有近似匹配的字符串的数量,最后用2^m-dp【m】【j】(j为结点数)即为答案。
#include<bits/stdc++.h>
using namespace std;
const int maxn=100+10;///多个模式串长度
const int lettersize=2;///看种类
char s[maxn];
long long dp[50][2000];
const int MAX=2000+10;
struct Trie
{
int next[MAX][lettersize],fail[MAX],end[MAX];
int root,L;
int newnode()
{
for (int i=0;i<lettersize;i++)
next[L][i]=-1;
end[L++]=0;
return L-1;
}
void init()
{
L=0;
root=newnode();
}
void insert(char *buf)
{
int len=strlen(buf);
int now=root;
for (int i=0;i<len;i++)
{
if (next[now][buf[i]-'0']==-1)
next[now][buf[i]-'0']=newnode();
now=next[now][buf[i]-'0'];
}
end[now]++;
}
void build ()
{
queue<int >Q;
fail[root]=root;
for (int i=0;i<lettersize;i++)
if (next[root][i]==-1)
next[root][i]=root;
else
{
fail[next[root][i]]=root;
Q.push(next[root][i]);
}
while (!Q.empty())
{
int now=Q.front();
Q.pop();
for (int i=0;i<lettersize;i++)
if (next[now][i]==-1)
next[now][i]=next[fail[now]][i];
else
{
fail[next[now][i]]=next[fail[now]][i];
Q.push(next[now][i]);
}
}
}
}ac;
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
ac.init();
int n,m;
scanf("%d%d",&n,&m);
scanf("%s",s);
if(n>m)
{
printf("0\n");
continue;
}
else if(n==m)
{
printf("%d\n",n+1);
continue;
}
ac.insert(s);
for(int i=0;i<n;i++)
{
if(s[i]=='0')
{
s[i]='1';
ac.insert(s);
s[i]='0';
}
else if(s[i]=='1')
{
s[i]='0';
ac.insert(s);
s[i]='1';
}
}
ac.build();
for(int i=0;i<ac.L;i++)
for(int j=0;j<2;j++)
// printf("next[%d][%d]=%d\n",i,j,ac.next[i][j]);
for(int i=0;i<=m;i++)
for(int j=0;j<ac.L;j++)
dp[i][j]=-1;
dp[0][0]=1;
for(int i=1;i<=m;i++)
{
for(int j=0;j<ac.L;j++)
{
if(dp[i-1][j]==-1|| ac.end[j])
continue;
// printf("dp %d %d=%lld->\n",i-1,j,dp[i-1][j]);
for(int k=0;k<2;k++)
{
int son=ac.next[j][k];
if(ac.end[son])
continue;
if(dp[i][son]==-1)
dp[i][son]=dp[i-1][j];
else
dp[i][son]+=dp[i-1][j];
// printf("dp %d %d=%lld\n",i,son,dp[i][son]);
}
}
}
long long ans=1;
for(int i=1;i<=m;i++)
ans*=2;
for(int j=1;j<ac.L;j++)
{
// printf("dp[%d][%d]=%lld\n",m,j,dp[m][j]);
if(dp[m][j]>0)
ans-=dp[m][j];
}
printf("%lld\n",ans);
}
return 0;
}