算法:AC自动机的理解,Trie图的应用。
我的理解:Trie图是建立在AC自动机基础上的,前者比后者有更多的应用。对于AC自动机的理解, 请参考这篇大牛推荐论文,http://www.cs.uku.fi/~kilpelai/BSA05/lectures/slides04.pdf,虽然是英文的,读了一个下午后,确实觉得讲的很明白; 对Trie的理解,请参考王贇的论文http://wenku.baidu.com/view/9df73e3567ec102de2bd89ae.html,写的很好。
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2243
题目描述:给你n个字符串S{s1, s2, ...., sn},只包含小写字母,问你长度不超过L,至少包含其中一个字符串(只有小写字母组成)的单词有多少个。
长度不超过L的单词总数减去不包含S中任意一个字符串的单词数, 即ans = 26^1 + 26^2 + … + 26^L - 所有不包含的。
有因为L最大为2^31,所以此题关键有两点:
求出所有不包含的数目; 快速求出前L次幂的和。
1)求所有不包含的数目:再理解AC自动机和Trie图之后,根据Trie图,我们建立邻接矩阵,我们就可以通过矩阵的前L次幂的和,得到所有不包含的数目。顺便在这里推荐一位大牛写的文章http://www.matrix67.com/blog/archives/276, 会更好的理解矩阵。
2)求出前L次幂的和:这个算法之前没有接触过,这次学习了。因为L最大为2^31,依次用快速幂求出每一项的值,复杂度nlgn, 肯定会Tle, 想了好久也没想出好的方法,只好去参考别人的代码,看后,一种感觉,基础不扎实,利用二分的思想,设和为S(n),
n为偶数时:S(n) = S(n/2)*pow(q,n/2)+S(n/2);
n为奇数时:S(n) = S(n/2)*pow(q,n/2)+S(n/2)+pow(q,n/2)。
根据代码,会更好理解此算法。
有了以上两点,这道题基本上也就没什么关键点了。
有一下几点注意:
(1)题目要求,结果模上2^64,只要结果输出格式为%I64u, 原因我也不知道为什么。。。 (尴尬==)
(2)由于在建立矩阵的时候,要用无符号64位的二维数组,题目描述n<6,我没多想,定义了map[10][10],可以运行,但是每次输入,都会栈溢出,纠结了两天,各种改才发现是申请的内存过大。但是,之前遇到内存过大而爆栈是,都不会运行的,不知道这次为什么。。 总之,此题要严格控制内存。
代码如有,有些长:
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
typedef unsigned __int64 u64;
const int maxn = 6; // 最多词根数
const int maxl = 6; // 词根最长长度
const int kind = 26; // 字符集字符数
typedef struct MATRIX{
u64 map[maxn*maxl][maxn*maxl];
}MAT;
struct node {
node* next[kind];
node* fail;
int id, flag;
};
node* root;
node* q[maxn*maxl];
MAT I, mat;
int ID = 0;
node* GetNode() {
node* p;
p = new node;
p->id = ID++;
p->flag = 0;
p->fail = NULL;
memset(p->next, NULL, sizeof(p->next));
return p;
}
void Insert(char word[]) { // Build Trie
int len = strlen(word);
node* p = root;
for (int i = 0; i < len; i++) {
int pos = word[i] - 'a';
if (p->next[pos] == NULL) {
p->next[pos] = GetNode();
}
p = p->next[pos];
}
p->flag = 1;
return ;
}
void BuildACAutomaton() {
int front, rear;
node* p;
front = rear = 0;
q[rear++] = root;
while (front < rear) {
p = q[front++];
for (int i = 0; i < kind; i++) {
if (p->next[i] != NULL) {
q[rear++] = p->next[i];
if (p == root) {
p->next[i]->fail = root;
} else {
node* tmp = p->fail;
while (tmp != NULL) {
if (tmp->next[i] != NULL) {
p->next[i]->fail = tmp->next[i];
if (tmp->next[i]->flag) { // 后缀节点为危险节点,当节点同为危险节点
p->next[i]->flag = 1;
}
break;
}
tmp = tmp->fail;
}
if (tmp == NULL) {
p->next[i]->fail = root;
}
}
}
}
}
return ;
}
void BuildGraph() {
node* p;
int front, rear;
memset(mat.map, 0, sizeof(mat.map));
front = rear = 0;
q[rear++] = root;
while (front < rear) {
p = q[front++];
if (p->flag) continue;
for (int i = 0; i < kind; i++) {
if (p->next[i] != NULL) {
if (p->next[i]->flag) continue;
q[rear++] = p->next[i];
mat.map[p->id][p->next[i]->id]++;
} else {
if (p == root) {
p->next[i] = p;
mat.map[0][0]++;
} else {
node* tmp = p->fail;
p->next[i] = tmp->next[i];
if (!tmp->next[i]->flag) mat.map[p->id][p->next[i]->id]++;
}
}
}
}
return ;
}
void InitI() {
memset(I.map, 0, sizeof(I.map));
for (int i = 0; i < ID; i++) {
I.map[i][i] = 1;
}
return ;
}
MAT MatrixAdd(MAT &a, MAT &b) {
MAT c;
for (int i = 0; i < ID; i++) {
for (int j = 0; j < ID; j++) {
c.map[i][j] = a.map[i][j] + b.map[i][j];
}
}
return c;
}
MAT MatrixMul(MAT &a, MAT &b) {
MAT c;
for (int i = 0; i < ID; i++) {
for (int j = 0; j < ID; j++) {
c.map[i][j] = 0;
for (int k = 0; k < ID; k++) {
c.map[i][j] += a.map[i][k] * b.map[k][j];
}
}
}
return c;
}
MAT MatrixSum(MAT m0, int t) {
MAT m1, m2;
int n = 0, bit[65] = {0};
m1 = m2 = m0;
while (t) {
bit[n++] = (t & 1);
t >>= 1;
}
for (int i = n-2; i >= 0; i--) {
m1 = MatrixMul(m1, MatrixAdd(m2, I));
m2 = MatrixMul(m2, m2);
if (bit[i]) {
m2 = MatrixMul(m2, m0);
m1 = MatrixAdd(m1, m2);
}
}
return m1;
}
int main () {
int n, m;
while (scanf("%d%d", &n, &m) != EOF) {
ID = 0;
char word[maxl];
root = new node;
root = GetNode();
for (int i = 0; i < n; i++) {
scanf("%s", word);
Insert(word);
}
BuildACAutomaton();
BuildGraph();
InitI();
MAT tmp;
memset(tmp.map, 0, sizeof(tmp.map));
tmp.map[0][0] = 26;
tmp = MatrixSum(tmp, m);
mat = MatrixSum(mat, m);
u64 ans = tmp.map[0][0];
for (int i = 0; i < ID; i++) {
ans -= mat.map[0][i];
}
printf("%I64u\n", ans);
}
return 0;
}