这道题类似于poj 1625,只不过字符集变得更小,只有’A’,’G’,’C’,’T’四个字符。相应的,序列的长度大幅度的增加至2*10^9。假设序列长度为m,总的trie图节点数为sz,原先的dp方法的复杂度为O(m*sz^2)。若仍采用dp的方法,那么总的操作数达到10^11,肯定不可取。因此需要其他的方法。矩阵可以解决这个问题。求出trie图的邻接矩阵,两点之间的权值为这两点间的边数,矩阵中的a[i][j]表示由节点i,到节点j走一步的方法数。此矩阵的n次幂后a’[i][j]即为由节点i走n步到节点j的方法数。因此题目的答案a[root][i]的和,(0<=i<sz)。
同样,要记得去除所有危险节点的行和列,即权值标为0,表示没有这条边。
矩阵的计算过程采用快速幂。复杂度为O(log(n))
更多的解释可以参见poj 1625
下面贴代码:
#include <cstdio>
#include <queue>
#include <cstring>
#define sigma_size 4
#define maxnode 105
#define modnum 100000
using namespace std;
int m, n;
int ch[maxnode][sigma_size];
int val[maxnode];
int sz;
int f[maxnode];
int last[maxnode];
void initial()
{
sz = 1;
memset(ch[0], 0, sizeof(ch[0]));
}
int charget(char a)
{
switch(a)
{
case 'A':
return 0;
case 'G':
return 1;
case 'C':
return 2;
case 'T':
return 3;
}
return -1;
}
void insert(char *S)
{
int l = strlen(S);
int u = 0;
for(int i = 0; i < l; i++)
{
int c = charget(S[i]);
if(!ch[u][c])
{
memset(ch[sz], 0, sizeof(ch[sz]));
val[sz] = 0;
ch[u][c] = sz++;
}
u = ch[u][c];
}
val[u] = 1;
}
void getfail()
{
queue<int>q;
f[0] = 0;
for(int c = 0; c < sigma_size; c++)
{
int u = ch[0][c];
if(u)
{
f[u] = 0;
q.push(u);
last[u] = 0;
}
}
while(!q.empty())
{
int r = q.front();
q.pop();
for(int c = 0; c < sigma_size; c++)
{
int u = ch[r][c];
if(!u)
{
ch[r][c] = ch[f[r]][c];
continue;
}
q.push(u);
int v = f[r];
while(v && !ch[v][c]) v = f[v];
f[u] = ch[v][c];
last[u] = val[f[u]] ? f[u] : last[f[u]];
}
}
}
long long matrix[maxnode][maxnode];
long long res1[maxnode][maxnode];
long long res2[maxnode][maxnode];
void buildmatrix()
{
memset(matrix, 0, sizeof(matrix));
for(int i = 0; i < sz; i++)
{
if(val[i] || last[i]) //去除非法节点出边
continue;
for(int j = 0 ; j < sigma_size; j++)
{
if(val[ch[i][j]] || last[ch[i][j]])
continue;//去除非法节点入边
matrix[i][ch[i][j]]++;
}
}
}
void mul(long long a[][maxnode], long long b[][maxnode], long long c[][maxnode])
{
memset(c, 0, sizeof(matrix));
for(int i = 0; i < sz; i++)
for(int j = 0; j < sz; j++)
{
for(int k = 0; k < sz; k++)
c[i][j] += (a[i][k]*b[k][j]);
c[i][j] %= modnum;
}
}
void swap(long long a[][maxnode],long long b[][maxnode])
{
long long tmp;
for(int i = 0; i <sz;i++)
for(int j = 0; j<sz;j++)
{
tmp = a[i][j];
a[i][j] = b[i][j];
b[i][j] = tmp;
}
}
void multiple(int n)
{
if(n == 1)
{
for(int i = 0; i < sz; i++)
for(int j = 0; j < sz; j++)
res1[i][j] = matrix[i][j];
return;
}
multiple(n / 2);
mul(res1, res1, res2);
if(n % 2)
{
mul(res2, matrix, res1);
}
else
swap(res2, res1);
}
char ban[12];
int main()
{
initial();
scanf("%d%d", &m, &n);
for(int i = 0; i < m; i++)
{
scanf("%s", ban);
insert(ban);
}
getfail();
buildmatrix();
multiple(n);
long long ans = 0;
for(int i = 0; i < sz; i++)
ans += res1[0][i];
printf("%lld\n",ans%modnum);
return 0;
}