题目大意:
众所周知DNA序列只含A、T、C、G四种基因,这种情况很利于科学分析基因密码和疾病缘由,在疾病控制领域,比如一种动物基因序列中含有ATC序列则表示该动物可能会得基因疾病,则这种基因片段则成为病变片段,现在科学家就找到了几种这样的病变片段,现在的问题是如何求出不含这些病变片段的所有可能基因序列。
现只有一个测例,测例中指定m条病变片段(0 ≤ m ≤ 10),以及该种动物DNA序列的长度n(1 ≤ n ≤ 2,000,000,000),接着会给出每条病变片段序列,所有序列都是由A、T、C、G组成,现求出长度为n的所有不含这些病变片段的DNA序列的总数,由于最终结果可能很大,因此结果模除100,000后再输出。
注释代码:
/*
* Problem ID : POJ 2778 DNA Sequence
* Author : Lirx.t.Una
* Language : C++
* Run Time : 313 ms
* Run Memory : 196 KB
*/
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
//思路:
//1. 用病变基因构造AC自动机(对所有病变单词结尾进行标记,被标记的状态就成为病变状态)
//这就表示在用AC自动机生成长度为n的句子时一路上绝对不能经过病变状态,如果经过
//病变状态的话就代表最后生成的句子中一定含有病变片段,如此一来只要想方设法去掉病变
//状态来重新构造AC自动机就行了
//而实际上不用那么复杂,也不需要重新构造新的AC自动机,而使用矩阵快速幂就可以使问题
//迎刃而解
//2. 在stat数组中剔除所有病变状态而得到sz个合法状态(离散化),并对所有合法状态标号1 ~ sz(sz也就
//是矩阵的宽度了,然后考察以每个状态为左部的产生式,假设其中有这样两个产生式:
//i -(C)-> j,i -(G)-> j,由于每条产生式都表示1步推导,因此这里就表示由i号状态到j号
//状态走1步的话有两种方法,一种是经过C到j,另一种是经过G到j,然后将任意i->j走一步的
//方法数初始化的矩阵m中,即m[i][j]初始化成状态i到状态j走1步的方法数,当然初始状态一定
//不是病变状态,并且将初始状态设为1号状态
//3. 求a = m^n(使用矩阵快速幂),初始化后,如果m = m × m,则结果中m[i][j]就表示从i号状态
//到j号状态走两步的方法数了
//依此类推m = m ^ n后,结果中m[i][j]就表示从i号状态到j号状态走n步的方法数了,由于病变状态
//不在矩阵中,因此最终结果一定不包含病变状态数
//4. 结果a中,a[1][i]就表示从1号状态(也就是初始状态)到i号走n步的方法数,那么∑(a[1][i])就是
//最终答案了
//模除因子
#define MOD 100000
//gene number
//基本基因的数量
//就A、T、C、G 4种
#define GEN 4
//maximum disease segment length
//病变片段的最大长度
#define MAXDSEGLEN 11
//maximum number of state
//状态数
#define MAXSTATN 45
using namespace std;
typedef __int64 llg;
typedef int (*MAT)[MAXSTATN];//矩阵类型
struct Node {
//为每种状态赋予一个id号
//再构造自动机时表示单词结尾
//在利用自动机产生句子的时候将被更新成合法状态的标号
int id;
int out[GEN];
int fail;
};
char ds[MAXDSEGLEN];//disease segment,存放病变片段
Node stat[MAXSTATN];
int tsn;
queue<int> q;
//main matrix,主矩阵
//初始时存放各个状态之间经一步转换的方法数
int m[MAXSTATN][MAXSTATN];
//temporary matrix
//临时矩阵,在做矩阵乘法时用于存放中间值
int t[MAXSTATN][MAXSTATN];
//answer matrix
//答案矩阵,a[i][j]表示从i状态到j状态走n步的方法数
int a[MAXSTATN][MAXSTATN];
int sz;//size,矩阵的大小,即宽度
int
xlat(char c) {//translate
//将基因字母转化为响应的数组下标
switch (c) {
case 'A' : return 0;
case 'T' : return 1;
case 'C' : return 2;
default : return 3;
}
}
void
insert(char *w) {//trie树插入
int s;
int g;//gene,当前输入字母
s = 0;
while ( *w ) {
g = xlat( *(w++) );
if ( !stat[s].out[g] ) {
stat[s].out[g] = ++tsn;
s = tsn;
continue;
}
s = stat[s].out[g];
}
stat[s].id = ~0;//单词结尾用-1表示(即每一位都是1的数),即表示病变状态
}
void
bd_AC_auto(void) {
int s, ss;//当前状态和当前状态推出的子状态
int r, rr;//fail指针回溯状态,以及回溯状态推出的子状态
int g;//gene,即自动机的当前弧
for ( g = 0; g < GEN; g++ )
if ( ss = stat[0].out[g] )
q.push(ss);
while ( !q.empty() ) {
s = q.front();
q.pop();
for ( g = 0; g < GEN; g++ )
if ( ss = stat[s].out[g] ) {
r = s;
while ( r = stat[r].fail )
if ( rr = stat[r].out[g] ) {
stat[ss].fail = rr;
//注意每连上一个fail指针就得
//更新一下其id号(此时id表示单词结尾)
//下面这句和这里的if语句等价,只不过更加简洁和快捷
//if ( stat[rr].id == -1 )
// stat[ss].id = -1;
stat[ss].id |= stat[rr].id;
break;
}
if ( !r )
if ( rr = stat[0].out[g] ) {
stat[ss].fail = rr;
stat[ss].id |= stat[rr].id;
}
else
stat[ss].fail = 0;
q.push(ss);
}
else//注意!!!这里使用自动机产生句子而不是用自动机扫描句子
//因此为了达到指定长度的句子,在产生句子的时候不能中断
//所以要将空的out连到s的fail.out上
//这样就仿佛在用自动机扫描句子,只要句子不结束,句子就在不停地
//向后扫描
stat[s].out[g] = stat[ stat[s].fail ].out[g];
}
}
void
setm(MAT m) {//set main matrix
//初始化主矩阵
//即离散化
int s;
int i, j;//i, j都为健康状态的id号
int g;
sz = 0;
for ( s = 0; s <= tsn; s++ )//剔除所有病变状态
if ( !stat[s].id )
stat[s].id = ++sz;//赋予每个健康状态一个id号,其中初始状态id一定是1
//最终sz也是健康状态的总数
for ( s = 0; s <= tsn; s++ )//用健康状态来初始化主矩阵
if ( ( i = stat[s].id ) != ~0 )
for ( g = 0; g < GEN; g++ )
if ( ( j = stat[ stat[s].out[g] ].id ) != ~0 )
m[i][j]++;//统计i号到j号走1步的走法总数
}
void
mul( MAT m1, MAT m2 ) {//矩阵乘法
//结果覆盖到m1中
int i, j, k;
memset(t, 0, sizeof(t));
for ( i = 1; i <= sz; i++ )
for ( k = 1; k <= sz; k++ )
if ( m1[i][k] )
for ( j = 1; j <= sz; j++ ) {
//注意大数模除
t[i][j] += (int)( (llg)m1[i][k] * (llg)m2[k][j] % (llg)MOD );
t[i][j] %= MOD;
}
memcpy(m1, t, sizeof(t));
}
int
fstpow( MAT m, int n, MAT a ) {//矩阵快速幂
//a=m^n
int i;
int ans;
for ( i = 1; i <= sz; i++ )
a[i][i] = 1;
while ( n ) {
if ( n & 1 )
mul( a, m );
mul( m, m );
n >>= 1;
}
for ( ans = 0, i = 1; i <= sz; i++ )
ans += a[1][i];
return ans % MOD;
}
int
main() {
int dm, n;//病变片段数和DNA的总长度
int i;
int ans;
scanf("%d%d", &dm, &n);
tsn = 0;
while ( dm-- ) {
scanf("%s", ds);
insert(ds);
}
bd_AC_auto();
setm(m);
printf("%d\n", fstpow( m, n, a ));
return 0;
}
无注释代码:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
#define MOD 100000
#define GEN 4
#define MAXDSEGLEN 11
#define MAXSTATN 45
using namespace std;
typedef __int64 llg;
typedef int (*MAT)[MAXSTATN];
struct Node {
int id;
int out[GEN];
int fail;
};
char ds[MAXDSEGLEN];
Node stat[MAXSTATN];
int tsn;
queue<int> q;
int m[MAXSTATN][MAXSTATN];
int t[MAXSTATN][MAXSTATN];
int a[MAXSTATN][MAXSTATN];
int sz;
int
xlat(char c) {
switch (c) {
case 'A' : return 0;
case 'T' : return 1;
case 'C' : return 2;
default : return 3;
}
}
void
insert(char *w) {
int s;
int g;
s = 0;
while ( *w ) {
g = xlat( *(w++) );
if ( !stat[s].out[g] ) {
stat[s].out[g] = ++tsn;
s = tsn;
continue;
}
s = stat[s].out[g];
}
stat[s].id = ~0;
}
void
bd_AC_auto(void) {
int s, ss;
int r, rr;
int g;
for ( g = 0; g < GEN; g++ )
if ( ss = stat[0].out[g] )
q.push(ss);
while ( !q.empty() ) {
s = q.front();
q.pop();
for ( g = 0; g < GEN; g++ )
if ( ss = stat[s].out[g] ) {
r = s;
while ( r = stat[r].fail )
if ( rr = stat[r].out[g] ) {
stat[ss].fail = rr;
stat[ss].id |= stat[rr].id;
break;
}
if ( !r )
if ( rr = stat[0].out[g] ) {
stat[ss].fail = rr;
stat[ss].id |= stat[rr].id;
}
else
stat[ss].fail = 0;
q.push(ss);
}
else
stat[s].out[g] = stat[ stat[s].fail ].out[g];
}
}
void
setm(MAT m) {
int s;
int i, j;
int g;
sz = 0;
for ( s = 0; s <= tsn; s++ )
if ( !stat[s].id )
stat[s].id = ++sz;
for ( s = 0; s <= tsn; s++ )
if ( ( i = stat[s].id ) != ~0 )
for ( g = 0; g < GEN; g++ )
if ( ( j = stat[ stat[s].out[g] ].id ) != ~0 )
m[i][j]++;
}
void
mul( MAT m1, MAT m2 ) {
int i, j, k;
memset(t, 0, sizeof(t));
for ( i = 1; i <= sz; i++ )
for ( k = 1; k <= sz; k++ )
if ( m1[i][k] )
for ( j = 1; j <= sz; j++ ) {
t[i][j] += (int)( (llg)m1[i][k] * (llg)m2[k][j] % (llg)MOD );
t[i][j] %= MOD;
}
memcpy(m1, t, sizeof(t));
}
int
fstpow( MAT m, int n, MAT a ) {
int i;
int ans;
for ( i = 1; i <= sz; i++ )
a[i][i] = 1;
while ( n ) {
if ( n & 1 )
mul( a, m );
mul( m, m );
n >>= 1;
}
for ( ans = 0, i = 1; i <= sz; i++ )
ans += a[1][i];
return ans % MOD;
}
int
main() {
int dm, n;
int i;
int ans;
scanf("%d%d", &dm, &n);
tsn = 0;
while ( dm-- ) {
scanf("%s", ds);
insert(ds);
}
bd_AC_auto();
setm(m);
printf("%d\n", fstpow( m, n, a ));
return 0;
}