题目大意:
给定m个01串,问有多少长度为2n的反回文01串,包含这m个串。
一个01串被称为反回文的,当且仅当
∀
i
∈
[
1
,
n
]
,
s
[
i
]
̸
=
s
[
n
−
i
+
1
]
\forall i\in[1,n],s[i]\not=s[n-i+1]
∀i∈[1,n],s[i]̸=s[n−i+1]。
m
≤
6
,
n
≤
500
,
∣
s
i
∣
≤
100
m\le6,n\le500,|s_i|\le100
m≤6,n≤500,∣si∣≤100
题解:
首先朴素做法是建两个AC自动机。
然后你发现你可以把原串和原串reverse后每位取反的串都插到同一个AC自动机里面。对于每个节点暴力跑出其对应哪些串即可。
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define mod 998244353
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline int inn()
{
int x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
const int LEN=110,TR=1210,MXS=(1<<6)+10,N=510;
char str[LEN];queue<int> q;int node_cnt,fa[TR],s[TR];
int a[LEN],fr[TR],ch[TR][2],frv[TR],dp[N][MXS][TR];
inline int ins(int *a,int n,int v)
{
int x=1;
for(int i=1,c;i<=n;fr[ch[x][c]]=x,x=ch[x][c],i++)
if(!ch[x][c=a[i]]) ch[x][c]=++node_cnt,frv[ch[x][c]]=c;
return s[x]|=v;
}
inline int getfail()
{
while(!q.empty()) q.pop();
rep(i,0,1)
{
int &c=ch[1][i];
if(c) fa[c]=1,q.push(c);
else c=1;
}
while(!q.empty())
{
int x=q.front();
s[x]|=s[fa[x]],q.pop();
rep(i,0,1)
{
int &y=ch[x][i],f=fa[x],c=ch[f][i];
if(y) fa[y]=c,q.push(y);else y=c;
}
}
return 0;
}
inline int gett(int x,int ans=0)
{
for(int y=x;fr[x];x=fr[x]) ans|=s[y=ch[y][frv[x]^1]];
return ans;
}
inline int upd(int &x,int y) { return x+=y,(x>=mod?x-=mod:0); }
int main()
{
int m=inn(),n=inn();node_cnt=1;
rep(qwq,1,m)
{
scanf("%s",str+1);int l=(int)strlen(str+1);
rep(i,1,l) a[i]=str[i]-'0';ins(a,l,1<<(qwq-1));
rep(i,1,l) a[i]=!a[i];
rep(i,1,l/2) swap(a[i],a[l-i+1]);
ins(a,l,1<<(qwq-1));
}
getfail(),dp[0][0][1]=1;
int nc=node_cnt,all=(1<<m)-1;
rep(i,0,n-1) rep(j,0,all) rep(k,1,nc)
if(dp[i][j][k]) for(int t=0,p;t<=1;t++)
p=ch[k][t],upd(dp[i+1][j|s[p]][p],dp[i][j][k]);
int ans=0;
rep(i,1,nc)
{
int t=gett(i);
rep(j,0,all) if((j|t)==all) upd(ans,dp[n][j][i]);
}
return !printf("%d\n",ans);
}