题目地址:http://poj.org/problem?id=2778
题目意思:
给你M个DNA的小序列
然后要你求出长度为N但是不含给出的M个小DNA的情况有多少种
这是一道很好的题目
对算法的要求很高,具体的思路我是在:http://blog.csdn.net/morgan_xww/article/details/7834801学来的,所以可以移步去看原创
我主要说说几个要注意的地方
首先就是AC自动机的创建
由于这个题目我们求的是一个跳转矩阵,所以和之前的匹配不是一个意思
那么在创建AC自动机这个数据结构的时候就要注意
主要的区别在于getfail()这个函数里面,个中真意还是自己体会一下比较好
然后就是快速矩阵幂了
我开始用一个比较粗糙的算法写的,但是直接RE了
后来看DISCUSS里面,反思了一下
因为我是开的long long 的101*101的矩阵
而且是用递归写的,用之前的粗糙的算法在递归的过程中我每一层都申请了矩阵的
那么出于保护现场的原因,我就相当于申请了很多的矩阵,最后,肯定就是内存HOLD不住了
所以,我修改了一下,我不用重新申请,直接从上一层引用,这样就可以大大的节约内存
特别是在层数很多的时候
下面上我的代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int maxnode = 10*10+5;
const int size=4;
const int mod=100000;
struct AC
{
int ch[maxnode][size];
int f[maxnode];
bool val[maxnode];
int sz;
void init()
{
memset(ch[0],-1,sizeof(ch[0]));
sz=1;
val[0]=false;
}
int idx(char c)
{
if(c=='A')
return 0;
else if(c=='C')
return 1;
else if(c=='T')
return 2;
else
return 3;
}
void insert(char *s)
{
int len = strlen(s);
int u=0;
for(int i=0;i<len;i++)
{
int c=idx(s[i]);
if(-1 == ch[u][c])
{
memset(ch[sz],-1,sizeof(ch[sz]));
val[sz]=false;
ch[u][c]=sz++;
}
u=ch[u][c];
}
val[u]=true;
}
void getfail()
{
queue<int> q;
for(int i=0;i<size;i++)
{
if(-1 != ch[0][i])
{
f[ch[0][i]] = 0;
q.push(ch[0][i]);
}
else
ch[0][i] = 0;
}
while(!q.empty())
{
int r=q.front();
q.pop();
//这里我们是要把trie变为一棵跳转的树,不是一棵匹配树
if(val[f[r]])
val[r]=true;
for(int i=0;i<size;i++)
{
int &v=ch[r][i];
if(-1 != v)
{
q.push(v);
f[v]=ch[f[r]][i];
}
else
v=ch[f[r]][i];
}
}
}
};
AC ac;
struct maxtrix
{
long long m[maxnode][maxnode];
};
void maxtrixmul(maxtrix a,maxtrix b,maxtrix &ans)
{
int n=ac.sz;
memset(ans.m,0,sizeof(ans.m));
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
for(int k=0;k<n;k++)
{
ans.m[i][j]+=a.m[i][k]*b.m[k][j];
}
ans.m[i][j] = ans.m[i][j]%mod;
}
}
}
void maxtrixpower(int n,maxtrix a,maxtrix &ans)
{
if(n==1)
{
ans=a;
return;
}
if(n&1)
{
int p=n-1;
maxtrixpower(p/2,a,ans);
maxtrixmul(ans,ans,ans);
maxtrixmul(ans,a,ans);
}
else
{
maxtrixpower(n/2,a,ans);
maxtrixmul(ans,ans,ans);
}
}
void build(maxtrix &ans)
{
memset(ans.m,0,sizeof(ans.m));
int n = ac.sz;
for(int i=0;i<n;i++)
{
for(int j=0;j<size;j++)
{
if(!ac.val[i] && !ac.val[ac.ch[i][j]])
{
ans.m[i][ac.ch[i][j]]++;
}
}
}
}
int main()
{
int mm,nm;
while(~scanf("%d%d",&mm,&nm))
{
char op[20];
ac.init();
for(int i=0;i<mm;i++)
{
scanf("%s",op);
ac.insert(op);
}
ac.getfail();
maxtrix mat;
build(mat);
maxtrixpower(nm,mat,mat);
long long cnt=0;
for(int i=0;i<ac.sz;i++)
cnt+=mat.m[0][i];
printf("%I64d\n",cnt%mod);
}
return 0;
}