【POJ 2778】DNA Sequence
Tags:AC自动机 齐次马氏链 矩阵快速幂 Trie图
题目描述
给定 m m m 个仅包含字符 A 、 T 、 C 、 G A、T、C、G A、T、C、G 的串(称为病毒串),再给定一个整数 n n n,作为待生成串的长度。
要求待生成串不能有子串是病毒串,问其共有多少种排列可能( % 100000 \%\ 100000 % 100000)(待生成串也只能包含 A 、 T 、 C 、 G A、T、C、G A、T、C、G)。
输入
仅一组。第一行两个正整数 m , n m,n m,n,分别代表病毒串个数和待生成串的长度。
接下来 m m m 行每行一个字符串,表示病毒串。
范围:
0
≤
m
≤
10
,
0
≤
n
≤
2
×
1
0
9
0≤m≤10,0≤n≤2 \times 10^9
0≤m≤10,0≤n≤2×109,病毒串的长度皆
≤
10
\le 10
≤10。
输出
一个非负整数,不含病毒串子串的排列可能数(
%
100000
\%\ 100000
% 100000)。
输入样例
4 3
AT
AC
AG
AA
输出样例
36
分析
思路
本题的思路启发于齐次马尔科夫链(定义就不列出了(大家随机过程都学的很好了
相当于需要构造一个字符串,构造的方式就是在 Trie 上跑 n n n 条边然后回到根结点。
当然在整个过程中不允许经过某些结点(该点对应的串包含某个病毒串)。
于是题意就化为:不经过某些特定的结点,共有几种跑法,因此用 邻接矩阵 做一个 快速幂 就行。
AC自动机的关键部分
① 实际上应该建立一个 Trie图(方便获取这个有向图的邻接矩阵)。
② 要在 BFS 建树时顺便把禁止经过的结点打标记(在这之后才能建邻接矩阵):
- 如果某个结点对应的串就是病毒串,显然它需要打标记。
- 如果某个结点起头的 f a i l fail fail 链上存在被打标记的点,则该结点也需打标记(说明曾经走过的某个后缀是病毒串)。
【CGWR】注意,这里又要对 f a i l fail fail 链做特殊处理。所以需要我们在建 T r i e Trie Trie 图 B F S BFS BFS 的时候即时判断、处理。(并且由于 B F S BFS BFS 是自顶向下的,所以也不会有遗漏。更不必在 B F S BFS BFS 之后再逐个检查 f a i l fail fail 链)
邻接矩阵快速幂的一个小优化
如果某些行列都是 0 0 0,那么作乘方的时候那些行/列将永远是 0 0 0。
因此可以将方阵进行适当的缩小(如果第 k k k 行之下、第 k k k 列之右全为 0 0 0,则可以整体 “剥去” ,只留下左上角的 k k k 阶顺序主子式)。
时间复杂度:
- 记自动机上总结点数为 N , N ≤ ∑ i = 1 n l e n i ≤ 100 N,N\le \sum\limits_{i=1}^{n} len_i \le 100 N,N≤i=1∑nleni≤100,待生成串长度为 M M M
- 插入病毒串进 T r i e Trie Trie, O ( ∑ i = 1 n l e n i ) O(\sum\limits_{i=1}^{n} len_i) O(i=1∑nleni) 的
- 建 T r i e Trie Trie 图(同时打标记) O ( N ) O(N) O(N) 的。
- 建邻接矩阵 O ( N ) O(N) O(N) 的。
- 快速幂 O ( N 3 log M ) O(N^3\log M) O(N3logM) 的 。
总时间复杂度 O ( N 3 log M ) O(N^3\log M) O(N3logM), N N N 不大,还行。
AC代码
#include <cstdio>
#include <cstring>
#define sc(x) {register char _c=getchar(),_v=1;for(x=0;_c<48||_c>57;_c=getchar())if(_c==45)_v=-1;for(;_c>=48&&_c<=57;x=(x<<1)+(x<<3)+_c-48,_c=getchar());x*=_v;}
#define MA 4
#define MOD 100000
#define MN 107
typedef long long vint;
class Mat
{
public:
int R, C;
vint m[MN][MN];
public:
Mat(const int R, const int C) :
R(R), C(C) { }
inline void clear(void)
{
memset(m, 0, sizeof(m));
}
void init(void)
{
memset(m, 0, sizeof(m));
for (int i=0; i<R; ++i)
m[i][i] = 1;
}
Mat operator *(const Mat &o) const
{
Mat ret(R, o.C);
for (int r=0; r<R; ++r)
{
for (int c=0; c<o.C; ++c)
{
ret.m[r][c] = *m[r] * (*o.m)[c];
for (int k=1; k<C; ++k)
{
ret.m[r][c] += m[r][k] * o.m[k][c];
if (ret.m[r][c] >= MOD)
ret.m[r][c] %= MOD;
}
}
}
return ret;
}
const Mat & operator *=(const Mat &o)
{
Mat tp(R, R);
for (int r=0; r<R; ++r)
{
for (int c=0; c<R; ++c)
{
tp.m[r][c] = *m[r] * (*o.m)[c];
for (int k=1; k<R; ++k)
{
tp.m[r][c] += m[r][k] * o.m[k][c];
if (tp.m[r][c] >= MOD)
tp.m[r][c] %= MOD;
}
}
}
*this = tp;
return *this;
}
const Mat &pow(int n)
{
Mat a(*this);
this->init();
while (n)
{
if(n&1)
this->operator *=(a);
a *= a;
n >>= 1;
}
return *this;
}
inline vint * operator [](const unsigned x)
{
return m[x];
}
};
class ACA
{
private:
const int ROOT;
int hs['T' + 1];
struct
{
int next[MA], fail;
bool bad;
} t[MN];
int tot;
int queue[MN];
Mat mgraph;
public:
ACA(void) : ROOT(0), mgraph(0, 0) { }
inline void init(void)
{
hs['A'] = 0, hs['T'] = 1;
hs['C'] = 2, hs['G'] = 3;
tot = 0;
t[ROOT].fail = -1;
}
void insert(const char *s)
{
int now = ROOT;
for (const char *p=s; *p; ++p)
{
if (!t[now].next[hs[*p]])
t[now].next[hs[*p]] = ++tot;
now = t[now].next[hs[*p]];
}
t[now].bad = true;
}
void bfs()
{
int head = 0, tail = 0;
for (int i=0; i<MA; ++i)
if (t[ROOT].next[i])
queue[tail++] = t[ROOT].next[i], t[t[ROOT].next[i]].fail = 0;
while (head != tail)
{
int u = queue[head++];
if (t[t[u].fail].bad)
t[u].bad = true;
for (int i=0; i<MA; ++i)
{
if (t[u].next[i])
{
queue[tail++] = t[u].next[i];
t[t[u].next[i]].fail = t[t[u].fail].next[i];
}
else
t[u].next[i] = t[t[u].fail].next[i];
}
}
}
void build()
{
int max_rc = 0;
for (int u=0; u<=tot; ++u)
{
for (int i=0; i<MA; ++i)
{
const int v = t[u].next[i];
if (t[u].bad || t[v].bad)
continue;
++mgraph[u][v];
if (max_rc < u)
max_rc = u;
if (max_rc < v)
max_rc = v;
}
}
mgraph.R = mgraph.C = max_rc + 1;
}
vint get_ans(const int n)
{
vint sum = 0;
mgraph.pow(n);
for (int c=0; c<mgraph.C; ++c)
sum += mgraph[0][c], sum %= MOD;
return sum;
}
};
ACA ac;
char s[MN];
int main()
{
int k, n;
sc(k)sc(n)
ac.init();
while (k--)
scanf("%s", s), ac.insert(s);
ac.bfs();
ac.build();
vint ans = ac.get_ans(n);
printf("%lld\n", ans);
return 0;
}