题目大意
给定一个串,长度为m, 询问存在多少长度为n的串,使得与长度m的串对应位置匹配最多只有一个字符不相同的方案数
题目分析
如果我们考虑串进行匹配的化,很难去重
因此我们反过来考虑
考虑匹配不上的情况
那么要匹配的串只有m+1个,所以我们只要考虑当前串的一个后缀是否是m+1个串的前缀即可。
然后通过后缀进行转移
f[i][j] 表示 前i项的后缀为状态j的时候的方案数
那么对于第i+1项,我们考虑 +0 或者 +1 此时得到一个新字符串 k
如果k是要匹配的串,则不算该串的值
否则,就去查找 k的最长后缀满足 m+1个串的前缀的状态
开始我们暴力去查找这个状态,就T了;
所以这个地方考虑用AC自动机,直接去查找fail指针就可以了
代码详解
#include <bits/stdc++.h>
using namespace std;
const int maxn = 4e4+50;
string s;
map<string,int>mp;
map<int,string>mp2;
map<string,int>mp3;
int idx=0;
typedef long long ll;
int trie[maxn][26];
int fail[maxn];
int idword[maxn];
int cnt=0;
void insertword(string s,int k)
{
int root = 0;
for(int i=0;i<s.size();i++)
{
int nxt = s[i]-'0';
if(!trie[root][nxt])
{
trie[root][nxt] = ++cnt;
}
root = trie[root][nxt];
}
idword[root] = k;
// cout<<s<<" "<<root<<" "<<k<<endl;
}
void getfail()
{
queue<int>q;
for(int i=0;i<2;i++)
{
if(trie[0][i])
{
fail[trie[0][i]] = 0;
q.push(trie[0][i]);
}
}
while(!q.empty())
{
int now = q.front();
q.pop();
for(int i=0;i<2;i++)
{
if(trie[now][i])
{
fail[trie[now][i]] = trie[fail[now]][i];
q.push(trie[now][i]);
}
else
trie[now][i] = trie[fail[now]][i];
}
}
}
int query(string s)
{
int now = 0,ans = 0;
for(int i=0;i<s.size();i++)
{
if(trie[now][s[i]-'0'])
now = trie[now][s[i]-'0'];
else now = trie[fail[now]][s[i]-'0'];
}
return idword[now];
}
void init(string k)
{
string tmp ;
for(int i=0;i<k.size()-1;i++)
{
tmp += k[i];
if(mp.count(tmp)==0)
{
mp[tmp] = ++idx;
mp2[idx] = tmp;
insertword(tmp,idx);
}
}
}
ll f[44][2001];
int main()
{
int T;
cin>>T;
while(T--)
{
mp.clear(); mp2.clear();mp3.clear();
cnt=0;
memset(trie,0,sizeof(trie));
memset(idword,-1,sizeof(idword));
memset(fail,0,sizeof(fail));
int n,m;
ll all = 1;
cin>>n>>m;
for(int i=1;i<=m;i++) all = all*2;
cin>>s;
idx=0;
init(s);
mp3[s] = 0;
memset(f,0,sizeof(f));
for(int i=0;i<s.size();i++)
{
s[i] = s[i]^1;
mp3[s] = i;
init(s);
s[i] = s[i]^1;
}
getfail();
string k ="1";
int r1 = mp[k];
k = "0";
int r2 = mp[k];
f[1][r1] = 1;
f[1][r2] = 1;
for(int i=1;i<=m;i++)
{
for(int j=1;j<=idx;j++)
{
string tmp = mp2[j];
string a,b;
a = tmp+"0";
if(mp3.count(a)==0)
{
int p = query(a);
f[i+1][p] += f[i][j];
}
a = tmp+"1";
if(mp3.count(a)==0)
{
int p = query(a);
f[i+1][p] += f[i][j];
}
}
}
int len = s.size();
ll sum = 0;
for(int i=1;i<=idx;i++)
{
sum += f[m][i];
}
cout<<all-sum<<endl;
}
return 0;
}