Poj 2778 [AC自动机,矩阵乘法]

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/morgan_xww/article/details/7834801

•题意:有m种DNA序列是有疾病的,问有多少种长度为n的DNA序列不包含任何一种有疾病的DNA序列。(仅含A,T,C,G四个字符)
•样例m=4,n=3,{“AA”,”AT”,”AC”,”AG”}
•答案为36,表示有36种长度为3的序列可以不包含疾病

这个和矩阵有什么关系呢???

•上图是例子{“ACG”,”C”},构建trie图后如图所示,从每个结点出发都有4条边(A,T,C,G)
•从状态0出发走一步有4种走法:
  –走A到状态1(安全);
  –走C到状态4(危险);
  –走T到状态0(安全);
  –走G到状态0(安全);
•所以当n=1时,答案就是3
•当n=2时,就是从状态0出发走2步,就形成一个长度为2的字符串,只要路径上没有经过危险结点,有几种走法,那么答案就是几种。依此类推走n步就形成长度为n的字符串。
•建立trie图的邻接矩阵M:

2 1 0 0 1

2 1 1 0 0

1 1 0 1 1

2 1 0 0 1

2 1 0 0 1

M[i,j]表示从结点i到j只走一步有几种走法。

那么M的n次幂就表示从结点i到j走n步有几种走法。

注意:危险结点要去掉,也就是去掉危险结点的行和列。结点3和4是单词结尾所以危险,结点2的fail指针指向4,当匹配”AC”时也就匹配了”C”,所以2也是危险的。

矩阵变成M:

2 1

2 1

计算M[][]的n次幂,然后 Σ(M[0,i]) mod 100000 就是答案。

由于n很大,可以使用二分来计算矩阵的幂

#include <cstdio>
#include <queue>
#include <algorithm>
#include <iostream>
#include <cstring>

using namespace std;

const int MAX_N = 10 * 10 + 5;   //最大结点数:模式串个数 X 模式串最大长度
const int CLD_NUM = 4;           //从每个结点出发的最多边数:本题是4个ATCG

typedef long long MATRIX[MAX_N][MAX_N];

MATRIX mat, mat1, mat2;
long long (*m1)[MAX_N], (*m2)[MAX_N];

class ACAutomaton
{
public:
    int  n;                          //当前结点总数
    int  id['Z'+1];                  //字母x对应的结点编号为id[x]
    int  fail[MAX_N];                //fail指针
    bool tag[MAX_N];                 //本题所需
    int  trie[MAX_N][CLD_NUM];       //trie tree

    void init()
    {
        id['A'] = 0;
        id['T'] = 1;
        id['C'] = 2;
        id['G'] = 3;
    }

    void reset()
    {
        memset(trie[0], -1, sizeof(trie[0]));
        tag[0] = false;
        n = 1;
    }

    //插入模式串s,构造单词树(keyword tree)
    void add(char *s)
    {
        int p = 0;
        while (*s)
        {
            int i = id[*s];
            if ( -1 == trie[p][i] )
            {
                memset(trie[n], -1, sizeof(trie[n]));
                tag[n] = false;
                trie[p][i] = n++;
            }
            p = trie[p][i];
            s++;
        }
        tag[p] = true;
    }

    //用BFS来计算每个结点的fail指针,构造trie树
    void construct()
    {
        queue<int> Q;
        fail[0] = 0;
        for (int i = 0; i < CLD_NUM; i++)
        {
            if (-1 != trie[0][i])
            {
                fail[trie[0][i]] = 0;
                Q.push(trie[0][i]);
            }
            else
            {
                trie[0][i] = 0;    //这个不能丢
            }
        }
        while ( !Q.empty() )
        {
            int u = Q.front();
            Q.pop();
            if (tag[fail[u]])
                tag[u] = true;         //这个很重要,当u的后缀是病毒,u也不能出现
            for (int i = 0; i < CLD_NUM; i++)
            {
                int &v = trie[u][i];
                if ( -1 != v )
                {
                    Q.push(v);
                    fail[v] = trie[fail[u]][i];
                }
                else
                {
                    v = trie[fail[u]][i];
                }
            }
        }
    }

    /* 根据trie树来构建状态转换的邻接矩阵mat[][]
       mat[i][j]表示状态i到状态j有几条边   */
    void buildMatrix()
    {
        memset(mat, 0, sizeof(mat));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < CLD_NUM; j++)
                if ( !tag[i] && !tag[trie[i][j]] )  //tag值为true的结点不能要,因为该结点的状态表示一个病毒
                    mat[i][trie[i][j]]++;
    }
} AC;

void matrixMult(MATRIX t1, MATRIX t2, MATRIX res)
{
    for (int i = 0; i < AC.n; i++)
        for (int j = 0; j < AC.n; j++)
        {
            res[i][j] = 0;
            for (int k = 0; k < AC.n; k++)
            {
                res[i][j] += t1[i][k] * t2[k][j];
            }
            res[i][j] %= 100000;
        }
}

/*
    递归二分计算矩阵的p次幂,结果存在m2[][]中
*/
void matrixPower(int p)
{
    if (p == 1)
    {
        for (int i = 0; i < AC.n; i++)
            for (int j = 0; j < AC.n; j++)
                m2[i][j] = mat[i][j];
        return;
    }

    matrixPower(p/2);          //计算矩阵的p/2次幂,结果存在m2[][]
    matrixMult(m2, m2, m1);    //计算矩阵m2的平方,结果存在m1[][]

    if (p % 2)                 //如果p为奇数,则再计算矩阵m1乘以原矩阵mat[][],结果存在m2[][]
        matrixMult(m1, mat, m2);
    else
        swap(m1, m2);
}

int main()
{
    int  n, m;
    char s[12];

    AC.init();
    cin >> m >> n;
    AC.reset();
    while ( m-- )
    {
        scanf("%s", s);
        AC.add(s);
    }
    AC.construct();
    AC.buildMatrix();

    m1 = mat1;
    m2 = mat2;
    matrixPower(n);

    int ans = 0;
    for (int i = 0; i < AC.n; i++)
        ans += m2[0][i];
    printf("%d\n", ans % 100000);

    return 0;
}


展开阅读全文

没有更多推荐了,返回首页