题目大意,找长度为n的不包含某些字符串的数目。
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <queue>
using namespace std;
const int maxn=200;
const long long mod=100000;
struct node
{
int next[4],flag,fail;
void init()
{
for (int i=0;i<4;i++)
next[i]=-1;
fail=-1;
flag=0;
}
}node[maxn];
char str[20];
int s[20];
int cnt;
long long a[maxn][maxn],ans[maxn][maxn];
int n,m;
void init()
{
cnt=1;
node[0].init();
memset(a,0,sizeof(a));
memset(ans,0,sizeof(ans));
ans[0][0]=1;
}
void change(char str[],int s[],int len)
{
for (int i=0;i<len;i++)
{
if (str[i]=='A') s[i]=0;
if (str[i]=='C') s[i]=1;
if (str[i]=='G') s[i]=2;
if (str[i]=='T') s[i]=3;
}
}
void insert(int s[],int len)
{
int root=0;
for (int i=0;i<len;i++)
{
if(node[root].next[s[i]]==-1)
{
node[cnt].init();
node[root].next[s[i]]=cnt++;
}
root=node[root].next[s[i]];
}
node[root].flag=1;
}
void getfail()
{
queue<int> q;
q.push(0);
while (!q.empty())
{
int root=q.front();
q.pop();
for (int i=0;i<4;i++)
if (node[root].next[i]!=-1)
{
int now=node[root].next[i];
int tmp=node[root].fail;
while (tmp!=-1&&node[tmp].next[i]==-1)
tmp=node[tmp].fail;
if (tmp!=-1)
{
node[now].fail=node[tmp].next[i];
node[now].flag|=node[node[now].fail].flag;
}
else node[now].fail=0;
q.push(now);
}
else
{
int tmp=node[root].fail;
while (tmp!=0&&node[tmp].next[i]==-1)
tmp=node[tmp].fail;
if (node[tmp].next[i]!=-1)
node[root].next[i]=node[tmp].next[i];
else node[root].next[i]=0;
}
}
}
void getM()
{
for (int i=0;i<cnt;i++)
for (int j=0;j<4;j++)
if (node[node[i].next[j]].flag==0)
a[i][node[i].next[j]]++;
}
void multiply()
{
long long res[200][200];
memset(res,0,sizeof(res));
for (int i=0;i<cnt;i++)
for (int j=0;j<cnt;j++)
for (int k=0;k<cnt;k++)
{
res[i][j]+=a[i][k]*a[k][j];
res[i][j]%=mod;
}
for (int i=0;i<cnt;i++)
for (int j=0;j<cnt;j++)
a[i][j]=res[i][j];
}
void work()
{
long long res[200][200];
memset(res,0,sizeof(res));
for (int i=0;i<cnt;i++)
for (int j=0;j<cnt;j++)
for (int k=0;k<cnt;k++)
{
res[i][j]+=ans[i][k]*a[k][j];
res[i][j]%=mod;
}
for (int i=0;i<cnt;i++)
for (int j=0;j<cnt;j++)
ans[i][j]=res[i][j];
}
int main()
{
// freopen("in.txt","r",stdin);
while (~scanf("%d%d",&m,&n))
{
init();
for (int i=0;i<m;i++)
{
scanf("%s",str);
change(str,s,strlen(str));
insert(s,strlen(str));
}
getfail();
getM();
// cout<<"hi"<<endl;
// cout<<cnt<<endl;
// for (int i=0;i<cnt;i++)
// {
// for (int j=0;j<4;j++)
// cout<<node[i].next[j]<<" ";
// cout<<node[i].fail<<" "<<node[i].flag;
// cout<<endl;
// }
if (n==0) {cout<<0<<endl;continue;}
while (n>0)
{
if (n%2)
{
work();
}
n/=2;
multiply();
}
long long res=0;
for (int i=0;i<cnt;i++)
for (int j=0;j<cnt;j++)
if (!node[j].flag)
{
res+=ans[i][j];
res%=mod;
}
cout<<res<<endl;
}
return 0;
}